diff --git a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu index 3a9374bb438..894531901ea 100644 --- a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -63,10 +64,14 @@ __global__ void roi_pool_forward_kernel_impl( int maxidx = -1; const T* offset_input = input + (roi_batch_ind * channels + c) * height * width; + using acc_t = at::acc_type; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int input_index = h * width + w; - if (offset_input[input_index] > maxval) { + acc_t v = static_cast(offset_input[input_index]); + acc_t mv = static_cast(maxval); + + if (v > mv) { maxval = offset_input[input_index]; maxidx = input_index; }