diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cu similarity index 99% rename from csrc/transformer/inference/csrc/pt_binding.cpp rename to csrc/transformer/inference/csrc/pt_binding.cu index 19dbe73726f7..bc524fc83f88 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cu @@ -11,6 +11,10 @@ #include "inference_cublas_wrappers.h" #include "inference_cuda_layers.h" +#ifdef BF16_AVAILABLE +#include <cuda_bf16.h> +#endif + std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99}); // NOTE: This activation function type enum should be always in sync diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 3afa74dc31c2..0c097b59eac3 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -55,7 +55,7 @@ def filter_ccs(self, ccs): def sources(self): return [ - 'csrc/transformer/inference/csrc/pt_binding.cpp', + 'csrc/transformer/inference/csrc/pt_binding.cu', 'csrc/transformer/inference/csrc/gelu.cu', 'csrc/transformer/inference/csrc/relu.cu', 'csrc/transformer/inference/csrc/layer_norm.cu', @@ -83,6 +83,6 @@ def nvcc_args(self): This cannot be avoided via forward declarations for this transformer_inference extension, since `pt_binding.cpp` code explicitly requires the BF16 header, so disable it for now. """ - if self.is_rocm_pytorch(): - self.enable_bf16 = False + #if self.is_rocm_pytorch(): + # self.enable_bf16 = False return args