diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index 46bcf4242..ebee930a1 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -75,7 +75,7 @@ if(USE_CUDA)
   list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
   target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
 else()
-  target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
+  target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand)
 endif()
 
 target_compile_options(test_operator PRIVATE -O2 -fopenmp)
diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index e1c963734..801b52712 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -740,4 +740,47 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
   [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
     return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
   });
+
+TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
+  const size_t rows = 128;
+  const size_t cols = 256;
+  const size_t N = rows * cols;
+
+  test::Tensor t("fillUniform_regression_fp32",
+                 std::vector<size_t>{rows, cols},
+                 transformer_engine::DType::kFloat32,
+                 /*rowwise=*/true,
+                 /*columnwise=*/false);
+
+  // Tensor constructor initializes CPU mirror + device to zero.
+  // If GPU generation happens but CPU mirror is not updated,
+  // any later test::Tensor::from_cpu() will overwrite device back to zeros.
+  fillUniform(&t);
+
+  // Check the CPU mirror has *actual* generated values, not all zeros
+  const float* cpu = t.rowwise_cpu_dptr<float>();
+
+  bool any_nonzero = false;
+  for (size_t i = 0; i < N; ++i) {
+    any_nonzero |= (cpu[i] != 0.0f);
+    if (any_nonzero)
+      break;
+  }
+
+  ASSERT_TRUE(any_nonzero) << "CPU mirror is all zeros. "
+                           << "Likely GPU-generated data got overwritten by from_cpu().";
+
+  // Check device matches CPU mirror after fillUniform completes
+  std::vector<float> dev(N, 0.0f);
+  NVTE_CHECK_CUDA(cudaMemcpy(dev.data(),
+                             t.rowwise_dptr(),
+                             N * sizeof(float),
+                             cudaMemcpyDeviceToHost));
+
+  for (size_t i = 0; i < N; ++i) {
+    ASSERT_EQ(dev[i], cpu[i]) << "Mismatch at i=" << i
+                              << " dev=" << dev[i] << " cpu=" << cpu[i];
+  }
+}
+
 #endif // __HIP_PLATFORM_AMD__
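For reference, the failure mode this regression test guards against can be reproduced outside the test harness. The following self-contained sketch (plain CUDA runtime API, hipified the same way as the sources above; names are illustrative and not part of the patch) shows how an unconditional mirror-to-device sync clobbers data that was produced directly on the device:

#include <cuda_runtime.h>
#include <cassert>
#include <vector>

// A zero-initialized host "mirror" plus an unconditional mirror->device copy
// reproduces the bug: device-side generated data is overwritten with zeros.
int main() {
  const size_t n = 16;
  std::vector<float> mirror(n, 0.0f);  // CPU mirror, never updated
  float* dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));

  // Stand-in for device-side RNG output:
  std::vector<float> generated(n, 0.5f);
  cudaMemcpy(dev, generated.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Stand-in for an unconditional test::Tensor::from_cpu()-style sync:
  cudaMemcpy(dev, mirror.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  std::vector<float> out(n);
  cudaMemcpy(out.data(), dev, n * sizeof(float), cudaMemcpyDeviceToHost);
  assert(out[0] == 0.0f);  // generated values are gone -- what the test asserts against
  cudaFree(dev);
  return 0;
}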
" + << "Likely GPU-generated data got overwritten by from_cpu()."; + + // Check device matches CPU mirror after fillUniform completes + std::vector dev(N, 0.0f); + NVTE_CHECK_CUDA(cudaMemcpy(dev.data(), + t.rowwise_dptr(), + N * sizeof(float), + cudaMemcpyDeviceToHost)); + + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(dev[i], cpu[i]) << "Mismatch at i=" << i + << " dev=" << dev[i] << " cpu=" << cpu[i]; + } +} + #endif // __HIP_PLATFORM_AMD__ diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 9f926d07b..7a89148fd 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -23,6 +23,8 @@ #include #include "util/logging.h" +#include + namespace test { size_t create_seed_from_tensor_name(const std::string& tensor_name) { @@ -786,16 +788,9 @@ std::pair getTolerances(const DType type) { return {0, 0}; } +#ifndef __HIP_PLATFORM_AMD__ template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { -#ifdef __HIP_PLATFORM_AMD__ - // TODO: Introduce a parallel RNG library (Random123, PCG, rocRAND) - std::uniform_real_distribution<> dis(-2.0, 1.0); - for (int i = 0; i < size; i++) { - data[i] = static_cast(dis(*gen)); - } - gen->discard(size); -#else // Check how many RNG calls are required to generate one uniform random value int rng_calls_per_val = 0; { @@ -825,28 +820,129 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { } } gen->discard(size * rng_calls_per_val); +} #endif + +#ifdef __HIP_PLATFORM_AMD__ +template +__global__ void affine_transform_cast_signs(const float* __restrict__ in, + const float* __restrict__ signs, + T* __restrict__ out, + size_t n, double lo, double hi) { + // Map values in *in* from [0, 1) to [lo, hi) and cast to type *T* for *out*. + // Potentially flip signs if RandomSign==true. + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float val = lo + (hi - lo) * in[idx]; + + if constexpr (RandomSign) { + if (signs[idx] < 0.5f) + val = -val; + } + + out[idx] = static_cast(val); + } +} + +template +static void fillUniformLinearBufferDevice(T* dst_dev, + T* dst_cpu, // nullable + size_t N, + unsigned long long seed, + double lo, double hi, + bool random_sign=false) { + // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*. + // Optionally mirror the result into a provided CPU pointer. 
+template <typename T>
+static void fillUniformLinearBufferDevice(T* dst_dev,
+                                          T* dst_cpu,  // nullable
+                                          size_t N,
+                                          unsigned long long seed,
+                                          double lo, double hi,
+                                          bool random_sign = false) {
+  // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*.
+  // Optionally mirror the result into a provided CPU pointer.
+  if (N == 0)
+    return;
+
+  float* tmp = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&tmp, N * sizeof(float)));
+
+  float* tmp_sign = nullptr;
+  if (random_sign) {
+    NVTE_CHECK_CUDA(cudaMalloc(&tmp_sign, N * sizeof(float)));
+  }
+
+  curandGenerator_t gen;
+  NVTE_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10) == CURAND_STATUS_SUCCESS);
+  NVTE_CHECK(curandSetPseudoRandomGeneratorSeed(gen, seed) == CURAND_STATUS_SUCCESS);
+  NVTE_CHECK(curandGenerateUniform(gen, tmp, N) == CURAND_STATUS_SUCCESS);
+
+  if (random_sign) {
+    NVTE_CHECK(curandGenerateUniform(gen, tmp_sign, N) == CURAND_STATUS_SUCCESS);
+  }
+
+  dim3 block(256);
+  dim3 grid((N + block.x - 1) / block.x);
+
+  if (random_sign)
+    affine_transform_cast_signs<T, true><<<grid, block>>>(
+        tmp, tmp_sign, dst_dev, N, lo, hi);
+  else
+    affine_transform_cast_signs<T, false><<<grid, block>>>(
+        tmp, nullptr, dst_dev, N, lo, hi);
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+
+  if (dst_cpu != nullptr) {
+    NVTE_CHECK_CUDA(cudaMemcpy(dst_cpu, dst_dev, N * sizeof(T), cudaMemcpyDeviceToHost));
+  }
+
+  NVTE_CHECK(curandDestroyGenerator(gen) == CURAND_STATUS_SUCCESS);
+  NVTE_CHECK_CUDA(cudaFree(tmp));
+  if (tmp_sign)
+    NVTE_CHECK_CUDA(cudaFree(tmp_sign));
+}
+
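A minimal usage sketch of the helper above (assumes fillUniformLinearBufferDevice and NVTE_CHECK_CUDA are in scope in the translation unit; the buffer size and seed are arbitrary):

void example_fill_linear_buffer() {
  const size_t n = 1024;
  float* dev = nullptr;
  NVTE_CHECK_CUDA(cudaMalloc(&dev, n * sizeof(float)));
  std::vector<float> host(n);
  // Uniforms in [-2, 1); host ends up mirroring the device buffer, and the
  // same seed always yields the same values.
  fillUniformLinearBufferDevice<float>(dev, host.data(), n,
                                       /*seed=*/1234ull,
                                       /*lo=*/-2.0, /*hi=*/1.0);
  NVTE_CHECK_CUDA(cudaFree(dev));
}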
+template <typename T>
+static void fillUniformTensorDevice(Tensor* t, double lo = -2.0,
+                                    double hi = 1.0, bool random_sign = false) {
+  void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr();
+  const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape());
+  const size_t N = product(shape);
+
+  // per-tensor deterministic seed
+  const unsigned long long seed = static_cast<unsigned long long>(t->gen()());
+
+  T* dst_dev = reinterpret_cast<T*>(dst_dev_void);
+  // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here,
+  // but that does more than just copying the data.
+  T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr<T>() : t->columnwise_cpu_dptr<T>();
+  fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign);
 }
+#endif
 
 void fillUniform(Tensor *t) {
   if (t->rowwise()) {
     const size_t size = product(t->rowwise_shape());
     TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
+#ifdef __HIP_PLATFORM_AMD__
+      fillUniformTensorDevice<T>(t);
+#else
       T *data = t->rowwise_cpu_dptr<T>();
       generate_data_uniformly(data, size, &(t->gen()));
+#endif
     }
     );
   } else {
     const size_t size = product(t->columnwise_shape());
     TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
+#ifdef __HIP_PLATFORM_AMD__
+      fillUniformTensorDevice<T>(t);
+#else
       T *data = t->columnwise_cpu_dptr<T>();
       generate_data_uniformly(data, size, &(t->gen()));
+#endif
     }
     );
   }
+#ifndef __HIP_PLATFORM_AMD__
+  // Data is already on device on AMDGPU
   t->from_cpu();
+#endif
   std::uniform_real_distribution<> dis(-2.0, 1.0);
   t->set_scale_inv(dis(t->gen()));
 }
@@ -857,10 +953,18 @@ void fillCase_special(Tensor *t) {
 
   if constexpr (Case == InputsFillCase::zeros) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
+#ifdef __HIP_PLATFORM_AMD__
+      // Fill device and CPU mirror
+      void* dst_dev = t->rowwise_dptr();
+      NVTE_CHECK_CUDA(cudaMemset(dst_dev, 0, size * sizeof(InputType)));
+      InputType* dst_cpu = t->rowwise_cpu_dptr<InputType>();
+      std::fill_n(dst_cpu, size, static_cast<InputType>(0));
+#else
       InputType *data = t->rowwise_cpu_dptr<InputType>();
       for (size_t i = 0; i < size; ++i) {
         data[i] = static_cast<InputType>(0);
       }
+#endif
     });
   } else {
     double minAbs = -2.0;
@@ -869,22 +973,32 @@ void fillCase_special(Tensor *t) {
       minAbs = Quantized_Limits<InputType>::ranges[Case];
       maxAbs = Quantized_Limits<InputType>::ranges[Case + 1];
     }
-    std::uniform_real_distribution<> dis(minAbs, maxAbs);
-    std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
-      InputType *data = t->rowwise_cpu_dptr<InputType>();
-      for (size_t idx = 0; idx < size; ++idx) {
-        const bool is_negative = (dis_sign(t->gen()) < 0.0);
-        double val = dis(t->gen());
-        if (is_negative) {
-          val = -val;
-        }
-        data[idx] = static_cast<InputType>(val);
-      }
+#ifdef __HIP_PLATFORM_AMD__
+      const unsigned long long seed = static_cast<unsigned long long>(t->gen()());
+      InputType* dst_dev = static_cast<InputType*>(t->rowwise_dptr());
+      InputType* dst_cpu = t->rowwise_cpu_dptr<InputType>();
+      fillUniformLinearBufferDevice(dst_dev, dst_cpu, size, seed,
+                                    minAbs, maxAbs, /*random_sign=*/true);
+#else
+      std::uniform_real_distribution<> dis(minAbs, maxAbs);
+      std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
+      InputType *data = t->rowwise_cpu_dptr<InputType>();
+      for (size_t idx = 0; idx < size; ++idx) {
+        const bool is_negative = (dis_sign(t->gen()) < 0.0);
+        double val = dis(t->gen());
+        if (is_negative) {
+          val = -val;
+        }
+        data[idx] = static_cast<InputType>(val);
+      }
+#endif
     });
   }
   t->set_scale_inv(1.0);
+#ifndef __HIP_PLATFORM_AMD__
   t->from_cpu();
+#endif
 }
 
 template <typename InputType, InputsFillCase Case>
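The per-tensor seeds drawn from t->gen()() are only useful because the Philox host-API generator is deterministic for a fixed seed. A standalone check of that property (illustrative; compiles against curand, or hiprand via hipify, independently of the patch):

#include <curand.h>
#include <cuda_runtime.h>
#include <cassert>
#include <vector>

std::vector<float> gen_uniform(unsigned long long seed, size_t n) {
  float* dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));
  curandGenerator_t g;
  curandCreateGenerator(&g, CURAND_RNG_PSEUDO_PHILOX4_32_10);
  curandSetPseudoRandomGeneratorSeed(g, seed);
  curandGenerateUniform(g, dev, n);  // fills dev with uniforms in (0, 1]
  std::vector<float> out(n);
  cudaMemcpy(out.data(), dev, n * sizeof(float), cudaMemcpyDeviceToHost);
  curandDestroyGenerator(g);
  cudaFree(dev);
  return out;
}

int main() {
  // Same seed, same sequence -- the basis for reproducible per-tensor fills.
  assert(gen_uniform(42ull, 256) == gen_uniform(42ull, 256));
  return 0;
}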
diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt
index e80ebffbc..51c855a91 100644
--- a/tests/cpp/util/CMakeLists.txt
+++ b/tests/cpp/util/CMakeLists.txt
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -21,7 +21,7 @@ find_package(OpenMP REQUIRED)
 if(USE_CUDA)
 target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX)
 else()
-target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
+target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand)
 endif()
 
 target_compile_options(test_util PRIVATE -O2 -fopenmp)
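The hiprand entries added to both link lines can be sanity-checked with a link-only smoke test: if hiprand is dropped from target_link_libraries, the generator symbols below fail to resolve at link time (illustrative snippet, not part of the patch):

#include <curand.h>  // hipified to hiprand on ROCm, matching the test sources

int main() {
  curandGenerator_t g;
  if (curandCreateGenerator(&g, CURAND_RNG_PSEUDO_PHILOX4_32_10) != CURAND_STATUS_SUCCESS)
    return 1;
  curandDestroyGenerator(g);
  return 0;
}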