diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 6ec3c27a4..f1f6fb739 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -16,6 +16,7 @@
     assert_allclose,
     pytest_parametrize_wrapper,
     use_jax_gemm,
+    _check_mxfp8_gemm_support,
 )
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp
@@ -57,7 +58,9 @@
     (2048, 2048, 1024),
     (2048, 1024, 1024),
 ]
-
+TEST_SHAPES = [(64, 32, 64)]
+if is_hip_extension():
+    TEST_SHAPES += [(64, 64, 128), (128, 256, 256)]
 
 jnp_float8_e4m3_type = get_jnp_float8_e4m3_type()
 jnp_float8_e5m2_type = get_jnp_float8_e5m2_type()
@@ -103,7 +106,7 @@ def assert_bitwise_scaled_tensors(
         assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
     elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
         # Compare MXFP8 scales as uint8
-        assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
+        assert_allclose(a.scale_inv.view(jnp.uint8), b.scale_inv.view(jnp.uint8))
     else:
         raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
     assert_allclose(a.data, b.data)
@@ -917,7 +920,7 @@ def test_gemm_bf16(self, m, n, k, data_layout):
         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
+    @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES)
     @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
@@ -927,8 +930,11 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi
             not with_jax_gemm
             and scaling_mode.is_1d_block_scaling()
             and jnp_float8_e5m2_type in (x_qtype, w_qtype)
+            and not is_hip_extension()
         ):
             pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
+        if scaling_mode.is_1d_block_scaling():
+            _check_mxfp8_gemm_support(with_jax_gemm, m, n, k)
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
 
         quantizer_set = QuantizerFactory.create_set(
@@ -979,15 +985,23 @@ def ref_func(x, w, data_layout):
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
+    @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
-    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
+    @pytest_parametrize_wrapper("use_bias", [False, True] if is_hip_extension() else [True])
+    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias):
         data_layout = "NN"
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
 
         key = jax.random.PRNGKey(1)
-        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)
+        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None
+
+        if scaling_mode.is_1d_block_scaling():
+            # Check for first GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
+            # Check for second GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+
 
         def primitive_func(x, w, bias, contracting_dims, quantizer_set):
             primitive_out = dense(
@@ -996,9 +1010,10 @@ def primitive_func(x, w, bias, contracting_dims, quantizer_set):
             return jnp.mean(primitive_out)
 
         def ref_func(x, w, bias, data_layout):
-            return jnp.mean(
-                self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
-            )
+            out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
+            if bias is not None:
+                out = out + jnp.expand_dims(bias, axis=0)
+            return jnp.mean(out)
 
         value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
         value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
@@ -1024,7 +1039,8 @@ def ref_func(x, w, bias, data_layout):
         assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type)
         assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type)
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type)
-        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type)
+        if bias is not None:
+            assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type)
 
 
 @pytest.fixture(name="random_inputs")
@@ -1049,7 +1065,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 
 class TestFusedDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
+    @pytest.mark.parametrize("m,n,k", TEST_SHAPES)
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
@@ -1057,6 +1073,11 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g
         """
         Test layernorm_dense VJP Rule
         """
+        if scaling_mode.is_1d_block_scaling():
+            # Check for fwd GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, n, k)
+            # Check for bwd GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, k, n)
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6
@@ -1128,7 +1149,7 @@ def ref_func(x, w, gamma, beta):
         assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
+    @pytest.mark.parametrize("m,n,k", TEST_SHAPES)
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@@ -1140,6 +1161,12 @@ def test_layernorm_mlp_grad(
         """
         Test layernorm_mlp VJP Rule
         """
+        if scaling_mode.is_1d_block_scaling():
+            # Check for first GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
+            # Check for second GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6
@@ -1345,6 +1372,9 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("layout", ["NN"])
     def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
+        if is_hip_extension() and scaling_mode.is_1d_block_scaling():
+            pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.")
+
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
             scaling_mode=scaling_mode,
@@ -1425,6 +1455,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
     )
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
+        if is_hip_extension() and scaling_mode.is_1d_block_scaling():
+            pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.")
+
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         dtype = jnp.bfloat16
         x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py
index 0af10d050..074dffcc5 100644
--- a/tests/jax/test_distributed_layernorm_mlp.py
+++ b/tests/jax/test_distributed_layernorm_mlp.py
@@ -16,6 +16,8 @@
     is_devices_enough,
     pytest_parametrize_wrapper,
     use_jax_gemm,
+    _check_mxfp8_layernorm_mlp_grad_support,
+    _check_mxfp8_layernorm_mlp_support,
 )
 
 from transformer_engine.common import recipe
@@ -70,6 +72,13 @@
 BIAS_2_AXES = (W_NO_SHARD_AXES,)
 
 INTERMEDIATE = 64
+# Use 256 here to ensure compatibility with hipblaslt MXFP8 GEMM, which
+# requires the reduction dim to be a multiple of 128 after sharding.
+if is_hip_extension():
+    INPUT_SHAPE += [[4, 64, 256]]
+    # TODO: Calculate intermediate size dynamically based on mesh config tpsp axis
+    INTERMEDIATE = 128 * 2
+
 
 # Only test with FSDP and TPSP as DP is not used
 def generate_fsdp_and_tpsp_configs():
@@ -167,6 +176,25 @@ def _test_layernorm_mlp_grad(
     use_shardy,
     with_jax_gemm,
 ):
+    if (
+        is_hip_extension()
+        and (not with_jax_gemm)
+        and use_bias
+        and (fp8_recipe is None)
+        and (dtype == jnp.bfloat16)
+    ):
+        pytest.xfail("Known failure: hipblaslt bf16 GEMM with bias.")
+    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+        _check_mxfp8_layernorm_mlp_grad_support(
+            input_shape[0] * input_shape[1],
+            INTERMEDIATE,
+            len(activation_type) * INTERMEDIATE,
+            input_shape[2],
+            input_shape[2],
+            mesh_config[1][1],
+            use_bias,
+            with_jax_gemm,
+        )
     jax.config.update("jax_use_shardy_partitioner", use_shardy)
     device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
     layernorm_type = "rmsnorm"
@@ -339,6 +367,17 @@ def _test_layernorm_mlp(
     use_shardy,
     with_jax_gemm,
 ):
+    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+        _check_mxfp8_layernorm_mlp_support(
+            input_shape[0] * input_shape[1],
+            INTERMEDIATE,
+            len(activation_type) * INTERMEDIATE,
+            input_shape[2],
+            input_shape[2],
+            mesh_config[1][1],
+            use_bias,
+            with_jax_gemm,
+        )
     jax.config.update("jax_use_shardy_partitioner", use_shardy)
     batch, seqlen, hidden_in = input_shape
     layernorm_type = "rmsnorm"
diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py
index 6f672ade7..d38ca3381 100644
--- a/tests/jax/test_layer.py
+++ b/tests/jax/test_layer.py
@@ -17,8 +17,12 @@
     dtype_tols,
     sync_params_values,
 )
-from utils import DecoderLayer as RefDecoderLayer
-from utils import EncoderLayer as RefEncoderLayer
+from utils import (
+    DecoderLayer as RefDecoderLayer,
+    EncoderLayer as RefEncoderLayer,
+    _check_mxfp8_layernorm_mlp_grad_support,
+    _check_mxfp8_layernorm_mlp_support,
+)
 
 from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
@@ -521,6 +525,15 @@ def test_backward(self, data_shape, dtype, attrs):
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
+        if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            _check_mxfp8_layernorm_mlp_support(
+                data_shape[0] * data_shape[1],
+                2048,
+                2048,
+                data_shape[2],
+                data_shape[2],
+                use_bias=attrs.get(_KEY_OF_USE_BIAS, False),
+            )
         # Empty MeshResource is used as we are running on a single device
         with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
@@ -529,6 +542,15 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
+        if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            _check_mxfp8_layernorm_mlp_grad_support(
+                data_shape[0] * data_shape[1],
+                2048,
+                2048,
+                data_shape[2],
+                data_shape[2],
+                use_bias=attrs.get(_KEY_OF_USE_BIAS, False),
+            )
         # Empty MeshResource is used as we are running on a single device
         with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 56d5df8e3..08c866a73 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -11,6 +11,7 @@
 import operator
 from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
 from contextlib import contextmanager
+from packaging import version
 
 import jax
 import jax.numpy as jnp
@@ -28,6 +29,7 @@
 )
 from transformer_engine.jax.quantize.helper import DType as TEDType
 from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type
+from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
 
 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -49,6 +51,94 @@ def is_devices_enough(required):
     return len(jax.devices()) >= required
 
 
+def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):
+    if not is_hip_extension():
+        return
+
+    if not with_jax_gemm:
+        if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0):
+            pytest.skip(
+                f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM."
+            )
+        if use_bias:
+            pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
+    else:
+        jax_version = version.parse(jax.__version__)
+        if jax_version < version.parse("0.8.0"):
+            pytest.skip(
+                "MXFP8 support for JAX GEMM was added in JAX 0.8.0, "
+                f"but the detected version is {jax_version}."
+            )
+
+def _check_mxfp8_layernorm_mlp_support(
+    batch_size,
+    intermediate_size,
+    activation_size,
+    hidden_in,
+    hidden_out,
+    n_tp_shards=1,
+    use_bias=False,
+    with_jax_gemm=False,
+):
+    # Check input shape compatibility with MXFP8 GEMMs
+    # FWD 1
+    m = batch_size
+    k = hidden_in // n_tp_shards  # Account for TP sharding
+    n = activation_size
+    _check_mxfp8_gemm_support(
+        with_jax_gemm,
+        m, n, k,
+        use_bias
+    )
+    # FWD 2
+    k = intermediate_size // n_tp_shards  # Account for TP sharding
+    n = hidden_out
+    _check_mxfp8_gemm_support(
+        with_jax_gemm,
+        m, n, k,
+        use_bias
+    )
+
+def _check_mxfp8_layernorm_mlp_grad_support(
+    batch_size,
+    intermediate_size,
+    activation_size,
+    hidden_in,
+    hidden_out,
+    n_tp_shards=1,
+    use_bias=False,
+    with_jax_gemm=False,
+):
+    # Check forwards
+    _check_mxfp8_layernorm_mlp_support(
+        batch_size,
+        intermediate_size,
+        activation_size,
+        hidden_in,
+        hidden_out,
+        n_tp_shards,
+        use_bias,
+        with_jax_gemm,
+    )
+    # BWD 1
+    m = batch_size
+    k = hidden_out // n_tp_shards  # Account for TP sharding
+    n = intermediate_size
+    _check_mxfp8_gemm_support(
+        with_jax_gemm,
+        m, n, k,
+        use_bias
+    )
+    # BWD 2
+    m = intermediate_size
+    k = batch_size // n_tp_shards  # Account for TP sharding
+    n = hidden_out
+    _check_mxfp8_gemm_support(
+        with_jax_gemm,
+        m, n, k,
+        use_bias
+    )
+
 def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
     # Generate broadcast dims for drop_path.
     drop_path_shape = list(range(0, len(shape)))
diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu
index fef3966a5..c2c4c502a 100644
--- a/transformer_engine/common/gemm/rocm_gemm.cu
+++ b/transformer_engine/common/gemm/rocm_gemm.cu
@@ -1522,6 +1522,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK((is_transb ? B0 : B1) == k,
              "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1,
              ", B is ", B0, "x", B1, ")");
 
+  // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM
+  if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) {
+    NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias.");
+    NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
+    NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
+    NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
+  }
   const int lda = is_transa ? k : m;
   const int ldb = is_transb ? n : k;
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 4ba581c66..28a42020f 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -406,8 +406,9 @@ def impl(
         rhs_scale_inv = apply_padding_to_scale_inv(
             rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
         )
-        lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
-        rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
+        if not is_hip_extension():
+            lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
+            rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
 
         outputs = GemmPrimitive.inner_primitive.bind(
             lhs,
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 89731e24a..ecf9aa5a8 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -321,14 +321,20 @@ def impl(
         rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
             scaling_mode
         ).get_scale_shape_2x(x.shape, is_padded=False)
-        # slice out padding for mxfp8, noop for DelayedScaling
-        scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
+        # Slice out the padding for mxfp8 -- the kernel writes to strided 2D
+        # positions, not contiguous. For 1D MXFP8: allocated [padded_rows,
+        # padded_cols], kernel writes [:actual_rows, :actual_cols]
+        scale_inv = jax.lax.slice(
+            scale_inv,
+            [0] * scale_inv.ndim,
             rowwise_scale_inv_shape
         )
         if is_2x:
-            colwise_scale_inv = colwise_scale_inv.flatten()[
-                : reduce(operator.mul, colwise_scale_inv_shape, 1)
-            ].reshape(colwise_scale_inv_shape)
+            colwise_scale_inv = jax.lax.slice(
+                colwise_scale_inv,
+                [0] * colwise_scale_inv.ndim,
+                colwise_scale_inv_shape
+            )
         return (
             out,
             colwise_out,
@@ -1002,20 +1008,6 @@ def layernorm_fwd(
         )
         colwise_scale_inv = rowwise_scale_inv
 
-    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
-    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
-    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
-    if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
-            x.shape, is_padded=False
-        )
-        rowwise_scale_inv = rowwise_scale_inv.flatten()[
-            : reduce(operator.mul, rowwise_unpadded_shape)
-        ].reshape(rowwise_unpadded_shape)
-        colwise_scale_inv = colwise_scale_inv.flatten()[
-            : reduce(operator.mul, colwise_unpadded_shape)
-        ].reshape(colwise_unpadded_shape)
-
     scaled_tensor = ScaledTensorFactory.create(
         data=rowwise_casted_output,
         scale_inv=rowwise_scale_inv,
@@ -1204,20 +1196,6 @@ def rmsnorm_fwd(
         )
         colwise_scale_inv = rowwise_scale_inv
 
-    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
-    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
-    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
-    if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
-            x.shape, is_padded=False
-        )
-        rowwise_scale_inv = rowwise_scale_inv.flatten()[
-            : reduce(operator.mul, rowwise_unpadded_shape)
-        ].reshape(rowwise_unpadded_shape)
-        colwise_scale_inv = colwise_scale_inv.flatten()[
-            : reduce(operator.mul, colwise_unpadded_shape)
-        ].reshape(colwise_unpadded_shape)
-
     scaled_tensor = ScaledTensorFactory.create(
         data=rowwise_casted_output,
         scale_inv=rowwise_scale_inv,
diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index 7015c2f5e..fba9bb916 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -321,11 +321,19 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
 
   const int arch = cuda::sm_arch();
+#ifndef __HIP_PLATFORM_AMD__
   if (arch < 100 && is_fp8_gemm) {
     NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
   }
+#else
+  if (arch < 95 && is_fp8_gemm) {
+    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
+              "For FP8 input on gfx942, only NT (row-major) GEMM is supported, ",
+              "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
+  }
+#endif
 
   // These lists are to keep the TensorWrapper objects alive
   std::vector<TensorWrapper> lhs_wrapper_list;
@@ -514,6 +522,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
 
   size_t num_non_empty_gemms = lhs_list.size();
 
+#ifndef __HIP_PLATFORM_AMD__
   if (is_mxfp8_scaling) {
     for (int i = 0; i < num_non_empty_gemms; i++) {
       // The i-th GEMM will use the (i % num_streams)-th stream to compute,
@@ -525,8 +534,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
       nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i);
     }
   }
+#endif
 
   // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM
   size_t num_zero_outs = zero_out_dptr_list.size();
   for (int i = 0; i < num_zero_outs; i++) {
     int stream_id = i % num_streams;
diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py
index fc957801a..264d9fc2b 100644
--- a/transformer_engine/jax/layernorm_mlp.py
+++ b/transformer_engine/jax/layernorm_mlp.py
@@ -260,7 +260,7 @@ def _layernorm_mlp_fwd_rule(
     assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
 
     use_bias_1 = bias_1 is not None
-    use_bias_2 = bias_1 is not None
+    use_bias_2 = bias_2 is not None
 
     x = with_sharding_constraint_by_logical_axes(x, norm_input_axes)
 
diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py
index 4037eae80..46b739bfe 100644
--- a/transformer_engine/jax/quantize/helper.py
+++ b/transformer_engine/jax/quantize/helper.py
@@ -90,6 +90,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
         A tuple of (bool, str) indicating support and any error message
     """
     if is_hip_extension():
+        if gpu_arch >= 95:
+            return True, ""
         return False, "FP8 block scaled gemm not yet supported for ROCm"
     if gpu_arch >= 100:  # blackwell and above
         return True, ""