From eb5e33cfbbdcd3ff1dc65e94123e3a7dca72d46c Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Tue, 16 Dec 2025 16:17:11 +0000 Subject: [PATCH] [Not_for_merge][AMD][Workaround] Enable bf16 for transform_inference extension Details https://github.com/deepspeedai/DeepSpeed/pull/7448#issuecomment-3397953771 Signed-off-by: Artem Kuzmitckii --- .../inference/csrc/{pt_binding.cpp => pt_binding.cu} | 4 ++++ op_builder/transformer_inference.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) rename csrc/transformer/inference/csrc/{pt_binding.cpp => pt_binding.cu} (99%) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cu similarity index 99% rename from csrc/transformer/inference/csrc/pt_binding.cpp rename to csrc/transformer/inference/csrc/pt_binding.cu index 19dbe73726f7..bc524fc83f88 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cu @@ -11,6 +11,10 @@ #include "inference_cublas_wrappers.h" #include "inference_cuda_layers.h" +#ifdef BF16_AVAILABLE +#include <cuda_bf16.h> +#endif + std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99}); // NOTE: This activation function type enum should be always in sync diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 3afa74dc31c2..0c097b59eac3 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -55,7 +55,7 @@ def filter_ccs(self, ccs): def sources(self): return [ - 'csrc/transformer/inference/csrc/pt_binding.cpp', + 'csrc/transformer/inference/csrc/pt_binding.cu', 'csrc/transformer/inference/csrc/gelu.cu', 'csrc/transformer/inference/csrc/relu.cu', 'csrc/transformer/inference/csrc/layer_norm.cu', @@ -83,6 +83,6 @@ def nvcc_args(self): This cannot be avoided via forward declarations for this transformer_inference extension, since `pt_binding.cpp` code explicitly requires the BF16 header, so disable it for now. 
""" - if self.is_rocm_pytorch(): - self.enable_bf16 = False + #if self.is_rocm_pytorch(): + # self.enable_bf16 = False return args