diff --git a/entmax/activations.py b/entmax/activations.py
index 55e10ed..f245f8d 100644
--- a/entmax/activations.py
+++ b/entmax/activations.py
@@ -13,6 +13,7 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 
 def _make_ix_like(X, dim):
@@ -144,6 +145,7 @@ def _entmax_threshold_and_support(X, dim=-1, k=None):
 
 class SparsemaxFunction(Function):
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, dim=-1, k=None):
         ctx.dim = dim
         max_val, _ = X.max(dim=dim, keepdim=True)
@@ -154,6 +156,7 @@ def forward(cls, ctx, X, dim=-1, k=None):
         return output, supp_size
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, grad_output, supp):
         supp_size, output = ctx.saved_tensors
         dim = ctx.dim
@@ -168,6 +171,7 @@ def backward(cls, ctx, grad_output, supp):
 
 class Entmax15Function(Function):
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, dim=0, k=None):
         ctx.dim = dim
 
@@ -182,6 +186,7 @@ def forward(cls, ctx, X, dim=0, k=None):
         return Y, supp_size
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY, supp):
         Y, = ctx.saved_tensors
         gppr = Y.sqrt()  # = 1 / g'' (Y)
diff --git a/entmax/root_finding.py b/entmax/root_finding.py
index 2df8923..d89e23c 100644
--- a/entmax/root_finding.py
+++ b/entmax/root_finding.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 
 class EntmaxBisectFunction(Function):
@@ -27,6 +28,7 @@ def _p(cls, X, alpha):
         return cls._gp_inv(torch.clamp(X, min=0), alpha)
 
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
 
         if not isinstance(alpha, torch.Tensor):
@@ -71,6 +73,7 @@ def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
         return p_m
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY):
         Y, = ctx.saved_tensors
 
@@ -118,12 +121,14 @@ def _p(cls, x, alpha):
         return torch.clamp(x, min=0)
 
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True):
         return super().forward(
             ctx, X, alpha=2, dim=dim, n_iter=50, ensure_sum_one=True
         )
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY):
         Y, = ctx.saved_tensors
         gppr = (Y > 0).to(dtype=dY.dtype)
diff --git a/entmax/test_amp.py b/entmax/test_amp.py
new file mode 100644
index 0000000..afd5509
--- /dev/null
+++ b/entmax/test_amp.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+from functools import partial
+
+from entmax import entmax15, sparsemax, entmax_bisect
+
+torch.manual_seed(42)
+
+def make_negatives(dtype, max_pow):
+    negatives = []
+    for i in range(2, max_pow + 1):
+        negative = torch.randn(128, dtype=dtype, device="cuda") - 10 ** i
+        negative[0] += 5
+        negatives.append(negative)
+    return negatives
+
+
+if torch.cuda.is_available():
+
+    mappings = [entmax15, sparsemax, partial(entmax_bisect, alpha=1.5), partial(entmax_bisect, alpha=2)]
+
+    long_bf16 = [
+        torch.randn(32000, dtype=torch.bfloat16, device="cuda")
+        for _ in range(5)
+    ]
+    negatives_bf16 = make_negatives(torch.bfloat16, 7)
+
+    long_fp16 = [
+        torch.randn(32000, dtype=torch.float16, device="cuda")
+        for _ in range(5)
+    ]
+    negatives_fp16 = make_negatives(torch.float16, 4)
+
+    @pytest.mark.parametrize("Xs", (long_bf16, negatives_bf16, long_fp16, negatives_fp16))
+    @pytest.mark.parametrize("func", mappings)
+    def test_probs_close(Xs, func):
+        dtype = Xs[0].dtype
+
+        full_precision_probs = [func(X.to(torch.float32), dim=-1) for X in Xs]
+        _Xs = [X.to(dtype) for X in Xs]
+        with torch.autocast(device_type="cuda", dtype=dtype):
+            for _X, fpp in zip(_Xs, full_precision_probs):
+                probs = func(_X, dim=-1)
+                assert torch.allclose(probs, fpp)
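
Usage sketch, not part of the patch above (it assumes a CUDA device and a made-up logits tensor): with custom_fwd(cast_inputs=torch.float32) and custom_bwd applied as in the diff, the entmax mappings can be called inside an autocast region without manual casting; half-precision inputs are promoted to float32 inside the Function, so the returned probabilities come back as float32, mirroring what the new test checks.

import torch
from entmax import entmax15

# Half-precision logits, e.g. as produced by a preceding autocast matmul.
logits = torch.randn(8, 32000, dtype=torch.float16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # custom_fwd(cast_inputs=torch.float32) casts the fp16 input to float32
    # before the thresholding runs, so probs is a float32 tensor.
    probs = entmax15(logits, dim=-1)

print(probs.dtype)  # torch.float32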