From 14e3b75f3f76117c3862c319a8ea624b17908c77 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Jan 2026 10:53:45 -0600 Subject: [PATCH 01/12] GEMMTestSuite: use rocrand for input data generation --- tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/test_common.cu | 50 +++++++++++++++++++++++++++++++ tests/cpp/util/CMakeLists.txt | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e3af4a360..fa51bee19 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -74,7 +74,7 @@ if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() - target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX) + target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) endif() target_compile_options(test_operator PRIVATE -O2 -fopenmp) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index a608f6ef2..f29d5b673 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -822,21 +822,71 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #endif } +#ifdef __HIP_PLATFORM_AMD__ +#include + +template +__global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { + // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + in[idx] = lo + (hi - lo) * in[idx]; + out[idx] = static_cast(in[idx]); + } +} + +void fillUniformDevice(Tensor* t) { + void* dst = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); + const auto shape = t->rowwise() ? t->rowwise_shape() : t->columnwise_shape(); + const size_t N = product(shape); + + float* tmp = nullptr; + hipMalloc(&tmp, N * sizeof(float)); + + // per-tensor deterministic seed + const unsigned long long seed = static_cast(t->gen()()); + rocrand_generator gen; + rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); + rocrand_set_seed(gen, seed); + + rocrand_generate_uniform(gen, tmp, N); + + // map to [-2, 1] (like generate_data_uniformly) and cast into tensor dtype + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { + dim3 block(256); + dim3 grid((N + block.x - 1) / block.x); + hipLaunchKernelGGL(affine_transform_and_cast, grid, block, 0, 0, + tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + }); + + rocrand_destroy_generator(gen); + hipFree(tmp); +} +#endif + void fillUniform(Tensor *t) { if (t->rowwise()) { const size_t size = product(t->rowwise_shape()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformDevice(t); +#else T *data = t->rowwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); +#endif } ); } else { const size_t size = product(t->columnwise_shape()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformDevice(t); +#else T *data = t->columnwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); +#endif } ); } diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index e80ebffbc..5d494a6d4 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -21,7 +21,7 @@ find_package(OpenMP REQUIRED) if(USE_CUDA) target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) else() -target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX) +target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) endif() target_compile_options(test_util PRIVATE -O2 -fopenmp) From 0d4d62fbd7ed7b9ed1c775060b25371a9ad4e4c7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Jan 2026 11:19:25 -0600 Subject: [PATCH 02/12] adjust comments --- tests/cpp/test_common.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f29d5b673..636b6e516 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -783,7 +783,6 @@ std::pair getTolerances(const DType type) { template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #ifdef __HIP_PLATFORM_AMD__ - // TODO: Introduce a parallel RNG library (Random123, PCG, rocRAND) std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = 0; i < size; i++) { data[i] = static_cast(dis(*gen)); @@ -851,7 +850,7 @@ void fillUniformDevice(Tensor* t) { rocrand_generate_uniform(gen, tmp, N); - // map to [-2, 1] (like generate_data_uniformly) and cast into tensor dtype + // map to [-2.0, 1.0] (like generate_data_uniformly) and cast into tensor dtype TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { dim3 block(256); dim3 grid((N + block.x - 1) / block.x); From 3f10ed3a87ac605c5969f332f4debfdf6d112808 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 19 Jan 2026 16:57:05 +0000 Subject: [PATCH 03/12] skip copying to device --- tests/cpp/test_common.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 0f78a0419..4f13409b3 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -895,7 +895,10 @@ void fillUniform(Tensor *t) { } ); } +#ifndef __HIP_PLATFORM_AMD__ +// Data is already on device on AMDGPU t->from_cpu(); +#endif std::uniform_real_distribution<> dis(-2.0, 1.0); t->set_scale_inv(dis(t->gen())); } From 0f008e9859e385ad0b40510ffc8dedc962b1526d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 19 Jan 2026 17:04:03 -0600 Subject: [PATCH 04/12] move include, use hipify more, fix CPU copy --- tests/cpp/test_common.cu | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4f13409b3..e0614b191 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -23,6 +23,10 @@ #include #include "util/logging.h" +#ifdef __HIP_PLATFORM_AMD__ +#include +#endif + namespace test { size_t create_seed_from_tensor_name(const std::string& tensor_name) { @@ -828,8 +832,6 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { } #ifdef __HIP_PLATFORM_AMD__ -#include - template __global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. @@ -846,7 +848,7 @@ void fillUniformDevice(Tensor* t) { const size_t N = product(shape); float* tmp = nullptr; - hipMalloc(&tmp, N * sizeof(float)); + cudaMalloc(&tmp, N * sizeof(float)); // per-tensor deterministic seed const unsigned long long seed = static_cast(t->gen()()); @@ -860,12 +862,17 @@ void fillUniformDevice(Tensor* t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { dim3 block(256); dim3 grid((N + block.x - 1) / block.x); - hipLaunchKernelGGL(affine_transform_and_cast, grid, block, 0, 0, - tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + affine_transform_and_cast<<>>( + tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + + // Copy into the CPU mirror. We could use Tensor::to_cpu() here, + // but that does more than just copying the data. + T* cpu_dst = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); + cudaMemcpy(cpu_dst, dst, N * sizeof(T), hipMemcpyDeviceToHost); }); rocrand_destroy_generator(gen); - hipFree(tmp); + cudaFree(tmp); } #endif @@ -896,7 +903,7 @@ void fillUniform(Tensor *t) { ); } #ifndef __HIP_PLATFORM_AMD__ -// Data is already on device on AMDGPU + // Data is already on device on AMDGPU t->from_cpu(); #endif std::uniform_real_distribution<> dis(-2.0, 1.0); From 4a4d138aeaaa43b5499e92c7d2550c6f95dc5d16 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jan 2026 10:10:44 -0600 Subject: [PATCH 05/12] remove now-superfluous AMD code and disable generate_data_uniformly --- tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/test_common.cu | 10 ++-------- tests/cpp/util/CMakeLists.txt | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e4728a760..da39d3a7f 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index e0614b191..bb91d2381 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -790,15 +790,9 @@ std::pair getTolerances(const DType type) { return {0, 0}; } +#ifndef __HIP_PLATFORM_AMD__ template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { -#ifdef __HIP_PLATFORM_AMD__ - std::uniform_real_distribution<> dis(-2.0, 1.0); - for (int i = 0; i < size; i++) { - data[i] = static_cast(dis(*gen)); - } - gen->discard(size); -#else // Check how many RNG calls are required to generate one uniform random value int rng_calls_per_val = 0; { @@ -828,8 +822,8 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { } } gen->discard(size * rng_calls_per_val); -#endif } +#endif #ifdef __HIP_PLATFORM_AMD__ template diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index 5d494a6d4..11973ee67 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From 6a89a41a704753b77923977cc250ed2e9ebc9886 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jan 2026 14:55:58 -0600 Subject: [PATCH 06/12] split fill function into linear+frontend --- tests/cpp/test_common.cu | 65 +++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index bb91d2381..8d2d3e47d 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -836,37 +836,54 @@ __global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict_ } } -void fillUniformDevice(Tensor* t) { - void* dst = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); - const auto shape = t->rowwise() ? t->rowwise_shape() : t->columnwise_shape(); - const size_t N = product(shape); +template +static void fillUniformLinearBufferDevice(T* dst_dev, + T* dst_cpu, // nullable + size_t N, + unsigned long long seed, + float lo, float hi) { + // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*. + // Optionally mirror the result into a provided CPU pointer. + if (N == 0) + return; float* tmp = nullptr; - cudaMalloc(&tmp, N * sizeof(float)); + NVTE_CHECK_CUDA(cudaMalloc(&tmp, N * sizeof(float))); - // per-tensor deterministic seed - const unsigned long long seed = static_cast(t->gen()()); rocrand_generator gen; - rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); - rocrand_set_seed(gen, seed); + NVTE_CHECK(rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10) == ROCRAND_STATUS_SUCCESS); + NVTE_CHECK(rocrand_set_seed(gen, seed) == ROCRAND_STATUS_SUCCESS); + NVTE_CHECK(rocrand_generate_uniform(gen, tmp, N) == ROCRAND_STATUS_SUCCESS); + + dim3 block(256); + dim3 grid((N + block.x - 1) / block.x); + affine_transform_and_cast<<>>( + tmp, reinterpret_cast(dst_dev), N, lo, hi); + NVTE_CHECK(cudaGetLastError() == hipSuccess); + + if (dst_cpu != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpy(dst_cpu, dst_dev, N * sizeof(T), cudaMemcpyDeviceToHost)); + } - rocrand_generate_uniform(gen, tmp, N); + NVTE_CHECK(rocrand_destroy_generator(gen) == ROCRAND_STATUS_SUCCESS); + NVTE_CHECK_CUDA(cudaFree(tmp)); +} - // map to [-2.0, 1.0] (like generate_data_uniformly) and cast into tensor dtype - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { - dim3 block(256); - dim3 grid((N + block.x - 1) / block.x); - affine_transform_and_cast<<>>( - tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); +static void fillUniformTensorDevice(Tensor* t) { + void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); + const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape()); + const size_t N = product(shape); + + // per-tensor deterministic seed + const unsigned long long seed = static_cast(t->gen()()); - // Copy into the CPU mirror. We could use Tensor::to_cpu() here, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { + T* dst_dev = reinterpret_cast(dst_dev_void); + // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, // but that does more than just copying the data. - T* cpu_dst = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); - cudaMemcpy(cpu_dst, dst, N * sizeof(T), hipMemcpyDeviceToHost); + T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, /*lo=*/-2.0f, /*hi=*/1.0f); }); - - rocrand_destroy_generator(gen); - cudaFree(tmp); } #endif @@ -876,7 +893,7 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { #ifdef __HIP_PLATFORM_AMD__ - fillUniformDevice(t); + fillUniformTensorDevice(t); #else T *data = t->rowwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); @@ -888,7 +905,7 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { #ifdef __HIP_PLATFORM_AMD__ - fillUniformDevice(t); + fillUniformTensorDevice(t); #else T *data = t->columnwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); From b6eee816a8906cbd1df3ea2a10e421406d5970b6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jan 2026 17:25:11 -0600 Subject: [PATCH 07/12] also offload fillCase_special --- tests/cpp/test_common.cu | 84 +++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 18 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 8d2d3e47d..ce005566a 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -827,21 +827,35 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #ifdef __HIP_PLATFORM_AMD__ template -__global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { +__global__ void affine_transform_and_cast(const float* __restrict__ in, + T* __restrict__ out, size_t n, double lo, + double hi) { // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { - in[idx] = lo + (hi - lo) * in[idx]; - out[idx] = static_cast(in[idx]); + out[idx] = static_cast(lo + (hi - lo) * in[idx]); + } +} + +template +__global__ void apply_random_sign(T* __restrict__ data, + const float* __restrict__ signs, + size_t n) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + if (signs[idx] < 0.5f) { + data[idx] = static_cast(-static_cast(data[idx])); + } } } template static void fillUniformLinearBufferDevice(T* dst_dev, - T* dst_cpu, // nullable + T* dst_cpu, // nullable size_t N, unsigned long long seed, - float lo, float hi) { + double lo, double hi, + bool random_sign=false) { // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*. // Optionally mirror the result into a provided CPU pointer. if (N == 0) @@ -850,15 +864,28 @@ static void fillUniformLinearBufferDevice(T* dst_dev, float* tmp = nullptr; NVTE_CHECK_CUDA(cudaMalloc(&tmp, N * sizeof(float))); + float* tmp_sign = nullptr; + if (random_sign) { + NVTE_CHECK_CUDA(cudaMalloc(&tmp_sign, N * sizeof(float))); + } + rocrand_generator gen; NVTE_CHECK(rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10) == ROCRAND_STATUS_SUCCESS); NVTE_CHECK(rocrand_set_seed(gen, seed) == ROCRAND_STATUS_SUCCESS); NVTE_CHECK(rocrand_generate_uniform(gen, tmp, N) == ROCRAND_STATUS_SUCCESS); + if (random_sign) { + NVTE_CHECK(rocrand_generate_uniform(gen, tmp_sign, N) == ROCRAND_STATUS_SUCCESS); + } + dim3 block(256); dim3 grid((N + block.x - 1) / block.x); affine_transform_and_cast<<>>( tmp, reinterpret_cast(dst_dev), N, lo, hi); + if (random_sign) { + apply_random_sign<<>>( + reinterpret_cast(dst_dev), tmp_sign, N); + } NVTE_CHECK(cudaGetLastError() == hipSuccess); if (dst_cpu != nullptr) { @@ -867,9 +894,12 @@ static void fillUniformLinearBufferDevice(T* dst_dev, NVTE_CHECK(rocrand_destroy_generator(gen) == ROCRAND_STATUS_SUCCESS); NVTE_CHECK_CUDA(cudaFree(tmp)); + if (tmp_sign) + cudaFree(tmp_sign); } -static void fillUniformTensorDevice(Tensor* t) { +static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, + double hi=1.0f, bool random_sign=false) { void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape()); const size_t N = product(shape); @@ -882,7 +912,7 @@ static void fillUniformTensorDevice(Tensor* t) { // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, // but that does more than just copying the data. T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); - fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, /*lo=*/-2.0f, /*hi=*/1.0f); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign); }); } #endif @@ -927,10 +957,18 @@ void fillCase_special(Tensor *t) { if constexpr (Case == InputsFillCase::zeros) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { +#ifdef __HIP_PLATFORM_AMD__ + // Fill device and CPU mirror + void* dst_dev = t->rowwise_dptr(); + NVTE_CHECK_CUDA(cudaMemset(dst_dev, 0, size * sizeof(InputType))); + InputType* dst_cpu = t->rowwise_cpu_dptr(); + std::fill_n(dst_cpu, size, static_cast(0)); +#else InputType *data = t->rowwise_cpu_dptr(); for (size_t i = 0; i < size; ++i) { data[i] = static_cast(0); } +#endif }); } else { double minAbs = -2.0; @@ -939,22 +977,32 @@ void fillCase_special(Tensor *t) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; } - std::uniform_real_distribution<> dis(minAbs, maxAbs); - std::uniform_real_distribution<> dis_sign(-1.0, 1.0); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { - InputType *data = t->rowwise_cpu_dptr(); - for (size_t idx = 0; idx < size; ++idx) { - const bool is_negative = (dis_sign(t->gen()) < 0.0); - double val = dis(t->gen()); - if (is_negative) { - val = -val; - } - data[idx] = static_cast(val); - } +#ifdef __HIP_PLATFORM_AMD__ + const unsigned long long seed = static_cast(t->gen()()); + InputType* dst_dev = static_cast(t->rowwise_dptr()); + InputType* dst_cpu = static_cast(t->rowwise_cpu_dptr()); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, size, seed, + minAbs, maxAbs, /*random_sign=*/true); +#else + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + InputType *data = t->rowwise_cpu_dptr(); + for (size_t idx = 0; idx < size; ++idx) { + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } +#endif }); } t->set_scale_inv(1.0); +#ifndef __HIP_PLATFORM_AMD__ t->from_cpu(); +#endif } template From 0ff7067749352b84c5b240c2b46c1ee6ca5bc9f9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 21 Jan 2026 13:05:31 -0600 Subject: [PATCH 08/12] move to curand/hiprand curand is already used in other places in TE. --- tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/test_common.cu | 18 +++++++++--------- tests/cpp/util/CMakeLists.txt | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index da39d3a7f..ebee930a1 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -75,7 +75,7 @@ if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() - target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) + target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() target_compile_options(test_operator PRIVATE -O2 -fopenmp) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ce005566a..630837f3b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -24,7 +24,7 @@ #include "util/logging.h" #ifdef __HIP_PLATFORM_AMD__ -#include +#include #endif namespace test { @@ -869,13 +869,13 @@ static void fillUniformLinearBufferDevice(T* dst_dev, NVTE_CHECK_CUDA(cudaMalloc(&tmp_sign, N * sizeof(float))); } - rocrand_generator gen; - NVTE_CHECK(rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10) == ROCRAND_STATUS_SUCCESS); - NVTE_CHECK(rocrand_set_seed(gen, seed) == ROCRAND_STATUS_SUCCESS); - NVTE_CHECK(rocrand_generate_uniform(gen, tmp, N) == ROCRAND_STATUS_SUCCESS); + curandGenerator_t gen; + NVTE_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10) == CURAND_STATUS_SUCCESS); + NVTE_CHECK(curandSetPseudoRandomGeneratorSeed(gen, seed) == CURAND_STATUS_SUCCESS); + NVTE_CHECK(curandGenerateUniform(gen, tmp, N) == CURAND_STATUS_SUCCESS); if (random_sign) { - NVTE_CHECK(rocrand_generate_uniform(gen, tmp_sign, N) == ROCRAND_STATUS_SUCCESS); + NVTE_CHECK(curandGenerateUniform(gen, tmp_sign, N) == CURAND_STATUS_SUCCESS); } dim3 block(256); @@ -886,16 +886,16 @@ static void fillUniformLinearBufferDevice(T* dst_dev, apply_random_sign<<>>( reinterpret_cast(dst_dev), tmp_sign, N); } - NVTE_CHECK(cudaGetLastError() == hipSuccess); + NVTE_CHECK_CUDA(cudaGetLastError()); if (dst_cpu != nullptr) { NVTE_CHECK_CUDA(cudaMemcpy(dst_cpu, dst_dev, N * sizeof(T), cudaMemcpyDeviceToHost)); } - NVTE_CHECK(rocrand_destroy_generator(gen) == ROCRAND_STATUS_SUCCESS); + NVTE_CHECK(curandDestroyGenerator(gen) == CURAND_STATUS_SUCCESS); NVTE_CHECK_CUDA(cudaFree(tmp)); if (tmp_sign) - cudaFree(tmp_sign); + NVTE_CHECK_CUDA(cudaFree(tmp_sign)); } static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index 11973ee67..51c855a91 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -21,7 +21,7 @@ find_package(OpenMP REQUIRED) if(USE_CUDA) target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) else() -target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) +target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() target_compile_options(test_util PRIVATE -O2 -fopenmp) From 097ecd4e8dd9e41c86a0095733e482afb9ab185b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 21 Jan 2026 14:21:51 -0600 Subject: [PATCH 09/12] add test for correct GPU->CPU mirroring --- tests/cpp/operator/test_cublaslt_gemm.cu | 45 +++++++++++++++++++++++- tests/cpp/test_common.cu | 3 ++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 61ca86a1e..386deb073 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -635,4 +635,47 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, [](const testing::TestParamInfo& info) { return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); }); + +TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { + const size_t rows = 128; + const size_t cols = 256; + const size_t N = rows * cols; + + test::Tensor t("fillUniform_regression_fp32", + std::vector{rows, cols}, + transformer_engine::DType::kFloat32, + /*rowwise=*/true, + /*columnwise=*/false); + + // Tensor constructor initializes CPU mirror + device to zero. + // If GPU generation happens but CPU mirror is not updated, + // any later test::Tensor::from_cpu() will overwrite device back to zeros. + fillUniform(&t); + + // Check the CPU mirror has *actual* generated values, not all zeros + const float* cpu = t.rowwise_cpu_dptr(); + + bool any_nonzero = false; + for (size_t i = 0; i < N; ++i) { + any_nonzero |= (cpu[i] != 0.0f); + if (any_nonzero) + break; + } + + ASSERT_TRUE(any_nonzero) << "CPU mirror is all zeros. " + << "Likely GPU-generated data got overwritten by from_cpu()."; + + // Check device matches CPU mirror after fillUniform completes + std::vector dev(N, 0.0f); + NVTE_CHECK_CUDA(cudaMemcpy(dev.data(), + t.rowwise_dptr(), + N * sizeof(float), + cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(dev[i], cpu[i]) << "Mismatch at i=" << i + << " dev=" << dev[i] << " cpu=" << cpu[i]; + } +} + #endif // __HIP_PLATFORM_AMD__ diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 630837f3b..607c0c4a7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -880,12 +880,15 @@ static void fillUniformLinearBufferDevice(T* dst_dev, dim3 block(256); dim3 grid((N + block.x - 1) / block.x); + affine_transform_and_cast<<>>( tmp, reinterpret_cast(dst_dev), N, lo, hi); + if (random_sign) { apply_random_sign<<>>( reinterpret_cast(dst_dev), tmp_sign, N); } + NVTE_CHECK_CUDA(cudaGetLastError()); if (dst_cpu != nullptr) { From dfd51e14a866ff4a1de9a0b16c71d8a8cefa6244 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 21 Jan 2026 16:13:29 -0600 Subject: [PATCH 10/12] remove extra __ifdef__ --- tests/cpp/test_common.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 607c0c4a7..93a7fdb9e 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -23,9 +23,7 @@ #include #include "util/logging.h" -#ifdef __HIP_PLATFORM_AMD__ #include -#endif namespace test { From ddfcf2d059f69475184a7ef8e4e8ceaf1610cf3c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 22 Jan 2026 12:37:17 -0600 Subject: [PATCH 11/12] fuse signs and transform kernels --- tests/cpp/test_common.cu | 42 ++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 93a7fdb9e..5b93328c8 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -824,26 +824,23 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #endif #ifdef __HIP_PLATFORM_AMD__ -template -__global__ void affine_transform_and_cast(const float* __restrict__ in, - T* __restrict__ out, size_t n, double lo, - double hi) { - // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. +template +__global__ void affine_transform_cast_signs(const float* __restrict__ in, + const float* __restrict__ signs, + T* __restrict__ out, + size_t n, double lo, double hi) { + // Map values in *in* from [0, 1) to [lo, hi) and cast to type *T* for *out*. + // Potentially flip signs if RandomSign==true. size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { - out[idx] = static_cast(lo + (hi - lo) * in[idx]); - } -} + float val = lo + (hi - lo) * in[idx]; -template -__global__ void apply_random_sign(T* __restrict__ data, - const float* __restrict__ signs, - size_t n) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - if (signs[idx] < 0.5f) { - data[idx] = static_cast(-static_cast(data[idx])); + if constexpr (RandomSign) { + if (signs[idx] < 0.5f) + val = -val; } + + out[idx] = static_cast(val); } } @@ -879,13 +876,12 @@ static void fillUniformLinearBufferDevice(T* dst_dev, dim3 block(256); dim3 grid((N + block.x - 1) / block.x); - affine_transform_and_cast<<>>( - tmp, reinterpret_cast(dst_dev), N, lo, hi); - - if (random_sign) { - apply_random_sign<<>>( - reinterpret_cast(dst_dev), tmp_sign, N); - } + if (random_sign) + affine_transform_cast_signs<<>>( + tmp, tmp_sign, reinterpret_cast(dst_dev), N, lo, hi); + else + affine_transform_cast_signs<<>>( + tmp, nullptr, reinterpret_cast(dst_dev), N, lo, hi); NVTE_CHECK_CUDA(cudaGetLastError()); From 33f6124511e4e85af536ba40854a3711b2f7e5aa Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 30 Jan 2026 17:07:16 +0000 Subject: [PATCH 12/12] clean up type switches --- tests/cpp/test_common.cu | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 5b93328c8..7a89148fd 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -878,10 +878,10 @@ static void fillUniformLinearBufferDevice(T* dst_dev, if (random_sign) affine_transform_cast_signs<<>>( - tmp, tmp_sign, reinterpret_cast(dst_dev), N, lo, hi); + tmp, tmp_sign, dst_dev, N, lo, hi); else affine_transform_cast_signs<<>>( - tmp, nullptr, reinterpret_cast(dst_dev), N, lo, hi); + tmp, nullptr, dst_dev, N, lo, hi); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -895,6 +895,7 @@ static void fillUniformLinearBufferDevice(T* dst_dev, NVTE_CHECK_CUDA(cudaFree(tmp_sign)); } +template static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, double hi=1.0f, bool random_sign=false) { void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); @@ -904,13 +905,11 @@ static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, // per-tensor deterministic seed const unsigned long long seed = static_cast(t->gen()()); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { - T* dst_dev = reinterpret_cast(dst_dev_void); - // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, - // but that does more than just copying the data. - T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); - fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign); - }); + T* dst_dev = reinterpret_cast(dst_dev_void); + // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, + // but that does more than just copying the data. + T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign); } #endif @@ -920,7 +919,7 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { #ifdef __HIP_PLATFORM_AMD__ - fillUniformTensorDevice(t); + fillUniformTensorDevice(t); #else T *data = t->rowwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); @@ -932,7 +931,7 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { #ifdef __HIP_PLATFORM_AMD__ - fillUniformTensorDevice(t); + fillUniformTensorDevice(t); #else T *data = t->columnwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen()));