4 changes: 2 additions & 2 deletions ci/pytorch.sh
@@ -71,8 +71,8 @@ run_test_config(){
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 triton_kernels/test_norms.py
NVTE_ROCM_ENABLE_MXFP8=1 NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
run_default_fa 1 test_parallel_cross_entropy.py
NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py
NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
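To exercise the new MXFP8 norm-test coverage outside of CI: run_default_fa is a wrapper defined in ci/pytorch.sh, so the commands below are only a hedged approximation of the same runs, assuming the tests are invoked from tests/pytorch (as the relative paths above suggest) with a plain pytest call.
cd tests/pytorch
NVTE_ROCM_ENABLE_MXFP8=1 python -m pytest -v triton_kernels/test_norms.py
NVTE_ROCM_ENABLE_MXFP8=1 NVTE_TEST_TRITON_AUTOTUNE=1 python -m pytest -v triton_kernels/test_norms.py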
62 changes: 32 additions & 30 deletions tests/pytorch/triton_kernels/test_norms.py
@@ -17,16 +17,21 @@
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor
from transformer_engine.pytorch.triton_kernels.rmsnorm import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from transformer_engine.pytorch.triton_kernels.layernorm import (
from transformer_engine.pytorch.triton_kernels.norms import (
te_layernorm_bwd_triton,
te_layernorm_fwd_triton,
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton,
)
from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform

def _compare_func(actual, expected, atol, rtol, msg, use_torch_semantics=False):
try:
te_compare_results(actual, expected, atol, rtol, msg, use_torch_semantics)
except AssertionError as e:
if "Tensor 'expected' has" in str(e) and "NaN" in str(e):
pytest.skip("HIP reference tensor contains NaN values.")
raise  # re-raise comparison failures not caused by NaN in the HIP reference

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -406,11 +411,11 @@ def _compare_output_tensors(
quantization, fp8_dtype
):
tols = dtype_tols(out_triton.dtype if quantization is None else fp8_dtype)
_compare_func = partial(te_compare_results, **tols, use_torch_semantics=True)
compare_func = partial(_compare_func, **tols, use_torch_semantics=True)

dq_out_triton = out_triton.dequantize()
dq_out_hip = out_hip.dequantize()
_compare_func(
compare_func(
actual=dq_out_triton,
expected=dq_out_hip,
msg=lambda msg: f"Output does not match triton <-> hip\n\n{msg}\n",
@@ -428,7 +433,7 @@ def _compare_output_tensors(
if not out_hip._transpose_invalid:
# The transpose data are generally uint8 so we must convert
# them for floating point comparison.
_compare_func(
compare_func(
actual=out_triton._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32),
expected=out_hip._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32),
msg=lambda msg: f"Output transpose does not match triton <-> hip\n\n{msg}\n",
@@ -438,11 +443,8 @@ def _compare_output_tensors(
if not isinstance(out_triton, MXFP8Tensor):
raise ValueError(f"Expected a MXFP8Tensor but got {type(out_triton)} instead.")

# TODO(micky774): Figure out if we need to apply the same view
# trick to MXFP8 data as we do to FP8 transpose data.
# I suspect not.
if out_hip._rowwise_data is not None:
_compare_func(
compare_func(
actual=out_triton,
expected=out_hip,
msg=lambda msg: f"Output rowwise data does not match triton <-> hip\n\n{msg}\n",
@@ -452,9 +454,9 @@ def _compare_output_tensors(
assert out_triton._rowwise_data is None, "Expected no rowwise data."

# We use higher precision for the scales
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
if quantization == "fp8":
_compare_func(
compare_func(
actual=out_triton._scale_inv,
expected=out_hip._scale_inv,
msg=lambda msg: f"Output scale inverse does not match triton <-> hip\n\n{msg}\n",
@@ -469,9 +471,9 @@ def _compare_output_tensors(
msg += "be None."
raise ValueError(msg)
if has_rscale_triton:
_compare_func(
actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
compare_func(
actual=out_triton._rowwise_scale_inv.view(torch.uint8),
expected=out_hip._rowwise_scale_inv.view(torch.uint8),
msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n",
)

@@ -484,9 +486,9 @@ def _compare_output_tensors(
msg += "be None."
raise ValueError(msg)
if has_cscale_triton:
_compare_func(
actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)),
compare_func(
actual=out_triton._columnwise_scale_inv.view(torch.uint8),
expected=out_hip._columnwise_scale_inv.view(torch.uint8),
msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n",
)

@@ -497,7 +499,7 @@ def _compare_quantizers(
quantization
):
if quantization is None: return
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True)

if quantizer_triton.dtype != quantizer_hip.dtype:
raise ValueError("Expected matching quantizer dtypes, but got "
@@ -511,12 +513,12 @@
raise ValueError(f"Expected matching quantizer {usage} but got {qt_usage=} != {qh_usage=}")

if quantization == "fp8":
_compare_func(
compare_func(
actual=quantizer_triton.scale,
expected=quantizer_hip.scale,
msg=lambda msg: f"Quantizer scale does not match triton <-> hip\n\n{msg}\n",
)
_compare_func(
compare_func(
actual=quantizer_triton.amax,
expected=quantizer_hip.amax,
msg=lambda msg: f"Quantizer amax does not match triton <-> hip\n\n{msg}\n",
@@ -529,15 +531,15 @@ def _compare_stat_tensors(
norm
):
# We use higher precision for the remaining outputs
_compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True)
compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True)

_compare_func(
compare_func(
actual=rsigma_triton,
expected=rsigma_hip,
msg=lambda msg: f"rsigma does not match triton <-> hip\n\n{msg}\n",
)
if norm == "layer":
_compare_func(
compare_func(
actual=mu_triton,
expected=mu_hip,
msg=lambda msg: f"mu does not match triton <-> hip\n\n{msg}\n",
@@ -579,20 +581,20 @@ def _compare_bwd_tensors(
dbeta_triton, dbeta_hip,
norm
):
_compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True)
compare_func = partial(_compare_func, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True)

_compare_func(
compare_func(
actual=dx_triton,
expected=dx_hip,
msg=lambda msg: f"dx does not match triton <-> hip\n\n{msg}\n",
)
_compare_func(
compare_func(
actual=dgamma_triton,
expected=dgamma_hip,
msg=lambda msg: f"dgamma does not match triton <-> hip\n\n{msg}\n",
)
if norm == "layer":
_compare_func(
compare_func(
actual=dbeta_triton,
expected=dbeta_hip,
msg=lambda msg: f"dbeta does not match triton <-> hip\n\n{msg}\n",
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/_common.py
@@ -20,8 +20,12 @@
from ..export import is_in_onnx_export_mode

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton
from ..triton_kernels.norms import (
te_layernorm_fwd_triton,
te_layernorm_bwd_triton,
te_rmsnorm_fwd_triton,
te_rmsnorm_bwd_triton
)

def _get_normalization_func(normalization: str, forward: bool):
use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -79,8 +79,7 @@
)

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton


__all__ = ["LayerNormLinear"]
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -83,8 +83,7 @@
from ...debug.pytorch.debug_state import TEDebugState

if IS_HIP_EXTENSION:
from ..triton_kernels.layernorm import te_layernorm_bwd_triton
from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton
from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton

__all__ = ["LayerNormMLP"]

4 changes: 3 additions & 1 deletion transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -17,7 +17,9 @@
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -17,7 +17,7 @@
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from ...triton_kernels.rmsnorm import (
from ...triton_kernels.norms import (
te_rmsnorm_bwd_triton,
te_rmsnorm_fwd_triton
)