diff --git a/setup.py b/setup.py index 60d297c35f3..56f80f85e3f 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ def get_macros_and_flags(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] else: - define_macros += [("WITH_CUDA", None)] + define_macros += [("WITH_CUDA", None), ("USE_CUDA", None)] if NVCC_FLAGS is None: nvcc_flags = [] else: @@ -283,10 +283,13 @@ def make_image_extension(): libraries = [] define_macros, extra_compile_args = get_macros_and_flags() + # PyTorch Stable ABI target version (2.11) - required for string handling in TORCH_BOX + define_macros += [("TORCH_TARGET_VERSION", "0x020b000000000000")] image_dir = CSRS_DIR / "io/image" sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c")) + # Always include CUDA sources - they have stubs when NVJPEG_FOUND is not defined if IS_ROCM: sources += list(image_dir.glob("hip/*.cpp")) # we need to exclude this in favor of the hipified source diff --git a/torchvision/csrc/StableABICompat.h b/torchvision/csrc/StableABICompat.h new file mode 100644 index 00000000000..530fa57b717 --- /dev/null +++ b/torchvision/csrc/StableABICompat.h @@ -0,0 +1,557 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +// =========================================================================== +// PyTorch Stable ABI Compatibility Header for TorchVision +// =========================================================================== +// +// This header provides compatibility types and macros for using PyTorch's +// stable ABI API. It replaces the standard PyTorch C++ APIs (torch::, at::, +// c10::) with their stable ABI equivalents. +// +// Target PyTorch version: 2.11+ +// +// Note: TORCH_TARGET_VERSION is set to 0x020b000000000000 (PyTorch 2.11) in +// CMakeLists.txt. 
This ensures we only use stable ABI features available in +// PyTorch 2.11+, providing forward compatibility when building against newer +// PyTorch versions. + +// Include stable ABI headers +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include AOTI headers for CUDA support +#include + +#include +#include + +// =========================================================================== +// Error Handling Macro +// =========================================================================== +// Replacement for TORCH_CHECK() that works with stable ABI. +// Uses STD_TORCH_CHECK from the stable ABI headers. +// Note: Unlike TORCH_CHECK, this always requires a message argument. + +#define VISION_CHECK(cond, ...) STD_TORCH_CHECK(cond, __VA_ARGS__) + +// =========================================================================== +// Type Aliases +// =========================================================================== +// Convenient aliases for stable ABI types in vision namespace + +namespace vision { +namespace stable { + +// Tensor types +using Tensor = torch::stable::Tensor; + +// Device types +using Device = torch::stable::Device; +using DeviceType = torch::headeronly::DeviceType; +using DeviceIndex = torch::stable::accelerator::DeviceIndex; + +// Scalar types (dtype) +using ScalarType = torch::headeronly::ScalarType; + +// DeviceGuard for CUDA context management +using DeviceGuard = torch::stable::accelerator::DeviceGuard; + +// Array reference type for sizes/strides +using IntArrayRef = torch::headeronly::IntHeaderOnlyArrayRef; + +// Layout and MemoryFormat +using Layout = torch::headeronly::Layout; +using MemoryFormat = torch::headeronly::MemoryFormat; + +// =========================================================================== +// Constants +// =========================================================================== + +// Device type constants +constexpr auto kCPU = 
torch::headeronly::DeviceType::CPU; +constexpr auto kCUDA = torch::headeronly::DeviceType::CUDA; + +// Scalar type constants (equivalents of at::kUInt8, at::kFloat32, etc.) +constexpr auto kByte = torch::headeronly::ScalarType::Byte; +constexpr auto kChar = torch::headeronly::ScalarType::Char; +constexpr auto kShort = torch::headeronly::ScalarType::Short; +constexpr auto kInt = torch::headeronly::ScalarType::Int; +constexpr auto kLong = torch::headeronly::ScalarType::Long; +constexpr auto kHalf = torch::headeronly::ScalarType::Half; +constexpr auto kFloat = torch::headeronly::ScalarType::Float; +constexpr auto kDouble = torch::headeronly::ScalarType::Double; +constexpr auto kBool = torch::headeronly::ScalarType::Bool; +constexpr auto kUInt16 = torch::headeronly::ScalarType::UInt16; + +// Layout constants +constexpr auto kStrided = torch::headeronly::Layout::Strided; + +// =========================================================================== +// Helper Functions - Tensor Creation +// =========================================================================== + +// Stable version of at::empty() +inline Tensor empty( + std::initializer_list sizes, + ScalarType dtype, + Device device) { + std::vector sizesVec(sizes); + return torch::stable::empty( + IntArrayRef(sizesVec.data(), sizesVec.size()), + dtype, + kStrided, + device); +} + +// Overload taking a vector +inline Tensor empty( + const std::vector& sizes, + ScalarType dtype, + Device device) { + return torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + device); +} + +// Helper to create CPU tensors +inline Tensor emptyCPU( + std::initializer_list sizes, + ScalarType dtype) { + return empty(sizes, dtype, Device(kCPU)); +} + +// Stable version of at::zeros() - creates via empty then zeros +inline Tensor zeros( + std::initializer_list sizes, + ScalarType dtype, + Device device) { + std::vector sizesVec(sizes); + auto tensor = torch::stable::empty( + 
IntArrayRef(sizesVec.data(), sizesVec.size()), + dtype, + kStrided, + device); + // Use dispatcher to call aten::zero_ + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// =========================================================================== +// Helper Functions - Tensor Operations +// =========================================================================== + +// Stable version of tensor.copy_(src) +inline void copy_(Tensor& dst, const Tensor& src) { + torch::stable::copy_(dst, src); +} + +// Stable version of tensor.to(device) +inline Tensor to(const Tensor& tensor, const Device& device) { + return torch::stable::to(tensor, device); +} + +// Note: narrow() is provided by torch::stable::narrow() directly +// Do NOT define a vision::stable::narrow wrapper as it conflicts with torch::stable::narrow + +// Note: contiguous() is provided by torch::stable::contiguous() directly +// Do NOT define a vision::stable::contiguous wrapper as it conflicts with the +// default parameter in torch::stable::contiguous(tensor, memory_format = Contiguous) + +// Note: select() is provided by torch::stable::select() directly +// Do NOT define a vision::stable::select wrapper as it conflicts with torch::stable::select + +// Helper for tensor.is_contiguous() +inline bool is_contiguous(const Tensor& tensor) { + return tensor.is_contiguous(); +} + +// =========================================================================== +// Helper Functions - Dispatcher Wrappers +// =========================================================================== + +// Stable version of tensor.sort() - returns (values, indices) +// Uses dispatcher to call aten::sort.default +inline std::pair sort( + const Tensor& tensor, + int64_t dim, + bool descending) { + std::array stack{ + torch::stable::detail::from(tensor), + 
torch::stable::detail::from(dim), + torch::stable::detail::from(descending)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::sort", "", stack.data(), TORCH_ABI_VERSION)); + return std::make_pair( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1])); +} + +// Stable version of argsort - returns indices only +inline Tensor argsort(const Tensor& tensor, int64_t dim, bool descending) { + auto [values, indices] = sort(tensor, dim, descending); + return indices; +} + +// Stable version of tensor.permute() +inline Tensor permute( + const Tensor& tensor, + std::initializer_list dims) { + std::vector dimsVec(dims); + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from( + IntArrayRef(dimsVec.data(), dimsVec.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::permute", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::cat() - concatenates tensors along a dimension +inline Tensor cat(const std::vector& tensors, int64_t dim = 0) { + std::array stack{ + torch::stable::detail::from(tensors), + torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::cat", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::clamp() +inline Tensor clamp( + const Tensor& tensor, + double min_val, + double max_val) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(min_val), + torch::stable::detail::from(max_val)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::clamp", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::floor() +// Named tensor_floor to avoid conflicts with std::floor in CUDA device code +inline Tensor tensor_floor(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + 
torch_call_dispatcher("aten::floor", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::ceil() +// Named tensor_ceil to avoid conflicts with std::ceil in CUDA device code +inline Tensor tensor_ceil(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::ceil", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.reshape() +inline Tensor reshape(const Tensor& tensor, const std::vector& shape) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(shape.data(), shape.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::reshape", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.view() +inline Tensor view(const Tensor& tensor, const std::vector& shape) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(shape.data(), shape.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::view", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.flatten(start_dim) +inline Tensor flatten(const Tensor& tensor, int64_t start_dim) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(start_dim), + torch::stable::detail::from(static_cast(-1))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Note: transpose() is provided by torch::stable::transpose() + +// Stable version of tensor.zero_() +inline Tensor& zero_(Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", 
stack.data(), TORCH_ABI_VERSION)); + tensor = torch::stable::detail::to(stack[0]); + return tensor; +} + +// Stable version of tensor.addmm_(mat1, mat2) +inline Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2) { + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(mat1), + torch::stable::detail::from(mat2), + torch::stable::detail::from(1.0), // beta + torch::stable::detail::from(1.0)}; // alpha + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::addmm_", "", stack.data(), TORCH_ABI_VERSION)); + self = torch::stable::detail::to(stack[0]); + return self; +} + +// Stable version of at::zeros with vector sizes +inline Tensor zeros( + const std::vector& sizes, + ScalarType dtype, + Device device) { + auto tensor = torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + device); + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::zeros_like() +inline Tensor zeros_like(const Tensor& tensor) { + // Use dispatcher to call aten::zeros_like + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zeros_like", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::ones_like() * val +inline Tensor ones_like_times(const Tensor& tensor, const Tensor& val) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(val)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::mul", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.sum(dims) +inline Tensor sum(const Tensor& tensor, const std::vector& dims) { + std::array stack{ + torch::stable::detail::from(tensor), + 
torch::stable::detail::from(IntArrayRef(dims.data(), dims.size())), + torch::stable::detail::from(false)}; // keepdim + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::sum", "dim_IntList", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor + other (broadcasting add) +inline Tensor add(const Tensor& self, const Tensor& other) { + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(other), + torch::stable::detail::from(1.0)}; // alpha + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::add", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.index_select(dim, index) +inline Tensor index_select(const Tensor& tensor, int64_t dim, const Tensor& index) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(dim), + torch::stable::detail::from(index)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::index_select", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.masked_select(mask) +inline Tensor masked_select(const Tensor& tensor, const Tensor& mask) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(mask)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::masked_select", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.flip(dims) +inline Tensor flip(const Tensor& tensor, std::initializer_list dims) { + std::vector dimsVec(dims); + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(dimsVec.data(), dimsVec.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::flip", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of torch.from_file() - read 
file into tensor +inline Tensor from_file( + const std::string& filename, + bool shared, + int64_t size, + ScalarType dtype) { + std::array stack{ + torch::stable::detail::from(filename), + torch::stable::detail::from(shared), + torch::stable::detail::from(size), + torch::stable::detail::from(dtype)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::from_file", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Note: squeeze() is provided by torch::stable::squeeze() + +// Stable version of tensor.unsqueeze(dim) +inline Tensor unsqueeze(const Tensor& tensor, int64_t dim) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.clone() +inline Tensor clone(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::clone", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Create empty tensor with ChannelsLast memory format +inline Tensor emptyCPUChannelsLast(const std::vector& sizes, ScalarType dtype) { + return torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + Device(kCPU), + std::nullopt, // pin_memory + MemoryFormat::ChannelsLast); +} + +// =========================================================================== +// Helper Functions - Utility +// =========================================================================== + +// Helper to get a human-readable name for a scalar type +inline const char* scalarTypeName(ScalarType dtype) { + switch (dtype) { + case ScalarType::Byte: + return "uint8"; + case ScalarType::Char: + return "int8"; + case ScalarType::Short: + return "int16"; + case ScalarType::Int: + return "int32"; + case 
ScalarType::Long: + return "int64"; + case ScalarType::Half: + return "float16"; + case ScalarType::Float: + return "float32"; + case ScalarType::Double: + return "float64"; + case ScalarType::Bool: + return "bool"; + default: + return "unknown"; + } +} + +// Helper to get a human-readable name for a device type +inline const char* deviceTypeName(DeviceType dtype) { + switch (dtype) { + case DeviceType::CPU: + return "cpu"; + case DeviceType::CUDA: + return "cuda"; + default: + return "unknown"; + } +} + +// Helper to convert IntArrayRef to a string for error messages +inline std::string intArrayRefToString(const IntArrayRef& arr) { + std::string result = "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) { + result += ", "; + } + result += std::to_string(arr[i]); + } + result += "]"; + return result; +} + +} // namespace stable + +// =========================================================================== +// CUDA Accumulator Type Traits +// =========================================================================== +// Replacement for at::acc_type that works without ATen headers +// Used in CUDA device code for accumulating sums with higher precision + +#ifdef __CUDACC__ +namespace cuda { + +// Primary template - default accumulator type is the same as input +template +struct acc_type { + using type = T; +}; + +// Specializations for CUDA: Half uses float for accumulation +template <> +struct acc_type<__half, true> { + using type = float; +}; + +// Float and double use themselves +template <> +struct acc_type { + using type = float; +}; + +template <> +struct acc_type { + using type = double; +}; + +// Helper alias for convenience +template +using acc_type_t = typename acc_type::type; + +} // namespace cuda +#endif // __CUDACC__ + +} // namespace vision diff --git a/torchvision/csrc/io/image/common.cpp b/torchvision/csrc/io/image/common.cpp index 7743961a09d..4f8d3c42b2a 100644 --- a/torchvision/csrc/io/image/common.cpp +++ 
b/torchvision/csrc/io/image/common.cpp @@ -12,13 +12,16 @@ void* PyInit_image(void) { namespace vision { namespace image { -void validate_encoded_data(const torch::Tensor& encoded_data) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, +using namespace vision::stable; + +void validate_encoded_data(const Tensor& encoded_data) { + VISION_CHECK( + encoded_data.is_contiguous(), "Input tensor must be contiguous."); + VISION_CHECK( + encoded_data.scalar_type() == kByte, "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( + scalarTypeName(encoded_data.scalar_type())); + VISION_CHECK( encoded_data.dim() == 1 && encoded_data.numel() > 0, "Input tensor must be 1-dimensional and non-empty, got ", encoded_data.dim(), diff --git a/torchvision/csrc/io/image/common.h b/torchvision/csrc/io/image/common.h index d81acfda7d4..84645c9452a 100644 --- a/torchvision/csrc/io/image/common.h +++ b/torchvision/csrc/io/image/common.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include "../../StableABICompat.h" namespace vision { namespace image { @@ -14,7 +14,7 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; const ImageReadMode IMAGE_READ_MODE_RGB = 3; const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; -void validate_encoded_data(const torch::Tensor& encoded_data); +void validate_encoded_data(const vision::stable::Tensor& encoded_data); bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( ImageReadMode mode, diff --git a/torchvision/csrc/io/image/cpu/decode_gif.cpp b/torchvision/csrc/io/image/cpu/decode_gif.cpp index 93b0861c5da..67fa3119ce3 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_gif.cpp @@ -6,6 +6,8 @@ namespace vision { namespace image { +using namespace vision::stable; + typedef struct reader_helper_t { uint8_t const* encoded_data; // input 
tensor data pointer size_t encoded_data_size; // size of input tensor in bytes @@ -30,7 +32,7 @@ int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) { return num_bytes_to_read; } -torch::Tensor decode_gif(const torch::Tensor& encoded_data) { +Tensor decode_gif(const Tensor& encoded_data) { // LibGif docs: https://giflib.sourceforge.net/intro.html // Refer over there for more details on the libgif API, API ref, and a // detailed description of the GIF format. @@ -57,13 +59,13 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // If we do that, we'd have to make sure the buffers are never written to by // GIFLIB, otherwise we'd be overridding the tensor data. reader_helper_t reader_helper; - reader_helper.encoded_data = encoded_data.data_ptr(); + reader_helper.encoded_data = encoded_data.const_data_ptr(); reader_helper.encoded_data_size = encoded_data.numel(); reader_helper.num_bytes_read = 0; GifFileType* gifFile = DGifOpen(static_cast(&reader_helper), read_from_tensor, &error); - TORCH_CHECK( + VISION_CHECK( (gifFile != nullptr) && (error == D_GIF_SUCCEEDED), "DGifOpenFileName() failed - ", error); @@ -71,12 +73,12 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { if (DGifSlurp(gifFile) == GIF_ERROR) { auto gifFileError = gifFile->Error; DGifCloseFile(gifFile, &error); - TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError); + VISION_CHECK(false, "DGifSlurp() failed - ", gifFileError); } auto num_images = gifFile->ImageCount; // This check should already done within DGifSlurp(), just to be safe - TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!"); + VISION_CHECK(num_images > 0, "GIF file should contain at least one image!"); GifColorType bg = {0, 0, 0}; if (gifFile->SColorMap) { @@ -94,12 +96,19 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // We output a channels-last tensor for consistency with other image decoders. 
// Torchvision's resize tends to be is faster on uint8 channels-last tensors. - auto options = torch::TensorOptions() - .dtype(torch::kU8) - .memory_format(torch::MemoryFormat::ChannelsLast); - auto out = torch::empty( - {int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options); - auto out_a = out.accessor(); + std::vector sizes = { + int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}; + auto out = emptyCPUChannelsLast(sizes, kByte); + auto out_ptr = out.mutable_data_ptr(); + + // Calculate strides for NCHW layout with ChannelsLast memory format + // In ChannelsLast format for NCHW tensor, memory is laid out as NHWC + // Stride order: N -> HWC, C -> 1, H -> WC, W -> C + int64_t stride_n = 3 * out_h * out_w; + int64_t stride_c = 1; + int64_t stride_h = out_w * 3; + int64_t stride_w = 3; + for (int i = 0; i < num_images; i++) { const SavedImage& img = gifFile->SavedImages[i]; @@ -109,7 +118,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { const GifImageDesc& desc = img.ImageDesc; const ColorMapObject* cmap = desc.ColorMap ? desc.ColorMap : gifFile->SColorMap; - TORCH_CHECK( + VISION_CHECK( cmap != nullptr, "Global and local color maps are missing. This should never happen!"); @@ -132,14 +141,18 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { (gcb.DisposalMode == DISPOSAL_UNSPECIFIED || gcb.DisposalMode == DISPOSE_DO_NOT || gcb.DisposalMode == DISPOSE_PREVIOUS)) { - out[i] = out[i - 1]; + // Copy previous frame to current frame + auto prev_frame_ptr = out_ptr + (i - 1) * stride_n; + auto curr_frame_ptr = out_ptr + i * stride_n; + std::memcpy(curr_frame_ptr, prev_frame_ptr, stride_n); } else { // Background. 
If bg wasn't defined, it will be (0, 0, 0) for (int h = 0; h < gifFile->SHeight; h++) { for (int w = 0; w < gifFile->SWidth; w++) { - out_a[i][0][h][w] = bg.Red; - out_a[i][1][h][w] = bg.Green; - out_a[i][2][h][w] = bg.Blue; + auto base_idx = i * stride_n + h * stride_h + w * stride_w; + out_ptr[base_idx + 0 * stride_c] = bg.Red; + out_ptr[base_idx + 1 * stride_c] = bg.Green; + out_ptr[base_idx + 2 * stride_c] = bg.Blue; } } } @@ -151,17 +164,20 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { continue; } GifColorType rgb = cmap->Colors[c]; - out_a[i][0][h + desc.Top][w + desc.Left] = rgb.Red; - out_a[i][1][h + desc.Top][w + desc.Left] = rgb.Green; - out_a[i][2][h + desc.Top][w + desc.Left] = rgb.Blue; + auto base_idx = i * stride_n + (h + desc.Top) * stride_h + + (w + desc.Left) * stride_w; + out_ptr[base_idx + 0 * stride_c] = rgb.Red; + out_ptr[base_idx + 1 * stride_c] = rgb.Green; + out_ptr[base_idx + 2 * stride_c] = rgb.Blue; } } } - out = out.squeeze(0); // remove batch dim if there's only one image + out = torch::stable::squeeze( + out, 0); // remove batch dim if there's only one image DGifCloseFile(gifFile, &error); - TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); + VISION_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); return out; } diff --git a/torchvision/csrc/io/image/cpu/decode_gif.h b/torchvision/csrc/io/image/cpu/decode_gif.h index 68d5073c91b..2b14112a219 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.h +++ b/torchvision/csrc/io/image/cpu/decode_gif.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { // encoded_data tensor must be 1D uint8 and contiguous -C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data); +vision::stable::Tensor decode_gif(const vision::stable::Tensor& encoded_data); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp 
b/torchvision/csrc/io/image/cpu/decode_image.cpp index 43a688604f6..aa536fef59e 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -1,4 +1,5 @@ #include "decode_image.h" +#include "../common.h" #include "decode_gif.h" #include "decode_jpeg.h" @@ -8,32 +9,34 @@ namespace vision { namespace image { -torch::Tensor decode_image( - const torch::Tensor& data, +using namespace vision::stable; + +Tensor decode_image( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { // Check that tensor is a CPU tensor - TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); + VISION_CHECK(data.device() == Device(kCPU), "Expected a CPU tensor"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + VISION_CHECK(data.scalar_type() == kByte, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional - TORCH_CHECK( + VISION_CHECK( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); auto err_msg = "Unsupported image file. Only jpeg, png, webp and gif are currently supported. 
For avif and heic format, please rely on `decode_avif` and `decode_heic` directly."; - auto datap = data.data_ptr(); + auto datap = data.const_data_ptr(); const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" - TORCH_CHECK(data.numel() >= 3, err_msg); + VISION_CHECK(data.numel() >= 3, err_msg); if (memcmp(jpeg_signature, datap, 3) == 0) { return decode_jpeg(data, mode, apply_exif_orientation); } const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" - TORCH_CHECK(data.numel() >= 4, err_msg); + VISION_CHECK(data.numel() >= 4, err_msg); if (memcmp(png_signature, datap, 4) == 0) { return decode_png(data, mode, apply_exif_orientation); } @@ -42,7 +45,7 @@ torch::Tensor decode_image( 0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a" const uint8_t gif_signature_2[6] = { 0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a" - TORCH_CHECK(data.numel() >= 6, err_msg); + VISION_CHECK(data.numel() >= 6, err_msg); if (memcmp(gif_signature_1, datap, 6) == 0 || memcmp(gif_signature_2, datap, 6) == 0) { return decode_gif(data); @@ -51,13 +54,13 @@ torch::Tensor decode_image( const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" - TORCH_CHECK(data.numel() >= 15, err_msg); + VISION_CHECK(data.numel() >= 15, err_msg); if ((memcmp(webp_signature_begin, datap, 4) == 0) && (memcmp(webp_signature_end, datap + 8, 7) == 0)) { return decode_webp(data, mode); } - TORCH_CHECK(false, err_msg); + VISION_CHECK(false, err_msg); } } // namespace image diff --git a/torchvision/csrc/io/image/cpu/decode_image.h b/torchvision/csrc/io/image/cpu/decode_image.h index f66d47eccd4..e20e6b355cf 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.h +++ b/torchvision/csrc/io/image/cpu/decode_image.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT 
torch::Tensor decode_image( - const torch::Tensor& data, +vision::stable::Tensor decode_image( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index 052b98e1be9..3e7c739b716 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -6,12 +6,14 @@ namespace vision { namespace image { +using namespace vision::stable; + #if !JPEG_FOUND -torch::Tensor decode_jpeg( - const torch::Tensor& data, +Tensor decode_jpeg( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - TORCH_CHECK( + VISION_CHECK( false, "decode_jpeg: torchvision not compiled with libjpeg support"); } #else @@ -129,19 +131,16 @@ void convert_line_cmyk_to_gray( } // namespace -torch::Tensor decode_jpeg( - const torch::Tensor& data, +Tensor decode_jpeg( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - validate_encoded_data(data); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; - auto datap = data.data_ptr(); + auto datap = data.const_data_ptr(); // Setup decompression structure cinfo.err = jpeg_std_error(&jerr.pub); jerr.pub.error_exit = torch_jpeg_error_exit; @@ -151,7 +150,7 @@ torch::Tensor decode_jpeg( * We need to clean up the JPEG object. 
*/ jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, jerr.jpegLastErrorMsg); + VISION_CHECK(false, jerr.jpegLastErrorMsg); } jpeg_create_decompress(&cinfo); @@ -192,7 +191,8 @@ torch::Tensor decode_jpeg( */ default: jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); + VISION_CHECK( + false, "The provided mode is not supported for JPEG files"); } jpeg_calc_output_dimensions(&cinfo); @@ -209,12 +209,11 @@ torch::Tensor decode_jpeg( int width = cinfo.output_width; int stride = width * channels; - auto tensor = - torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.data_ptr(); - torch::Tensor cmyk_line_tensor; + auto tensor = emptyCPU({int64_t(height), int64_t(width), channels}, kByte); + auto ptr = tensor.mutable_data_ptr(); + Tensor cmyk_line_tensor; if (cmyk_to_rgb_or_gray) { - cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + cmyk_line_tensor = emptyCPU({int64_t(width), 4}, kByte); } while (cinfo.output_scanline < cinfo.output_height) { @@ -223,7 +222,7 @@ torch::Tensor decode_jpeg( * more than one scanline at a time if that's more convenient. 
*/ if (cmyk_to_rgb_or_gray) { - auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + auto cmyk_line_ptr = cmyk_line_tensor.mutable_data_ptr(); jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); if (channels == 3) { @@ -239,7 +238,7 @@ torch::Tensor decode_jpeg( jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); - auto output = tensor.permute({2, 0, 1}); + auto output = permute(tensor, {2, 0, 1}); if (apply_exif_orientation) { return exif_orientation_transform(output, exif_orientation); diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.h b/torchvision/csrc/io/image/cpu/decode_jpeg.h index 7412a46d2ea..ed8760a5862 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.h @@ -1,18 +1,18 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_jpeg( - const torch::Tensor& data, +vision::stable::Tensor decode_jpeg( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); -C10_EXPORT int64_t _jpeg_version(); -C10_EXPORT bool _is_compiled_against_turbo(); +int64_t _jpeg_version(); +bool _is_compiled_against_turbo(); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 5ea6f073975..583a24fa679 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -6,14 +6,15 @@ namespace vision { namespace image { +using namespace vision::stable; using namespace exif_private; #if !PNG_FOUND -torch::Tensor decode_png( - const torch::Tensor& data, +Tensor decode_png( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - TORCH_CHECK( + VISION_CHECK( false, "decode_png: torchvision not compiled with libPNG support"); } #else @@ -23,35 +24,32 @@ bool is_little_endian() { return 
*(uint8_t*)&x; } -torch::Tensor decode_png( - const torch::Tensor& data, +Tensor decode_png( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - validate_encoded_data(data); auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); - TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + VISION_CHECK(png_ptr, "libpng read structure allocation failed!") auto info_ptr = png_create_info_struct(png_ptr); if (!info_ptr) { png_destroy_read_struct(&png_ptr, nullptr, nullptr); // Seems redundant with the if statement. done here to avoid leaking memory. - TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + VISION_CHECK(info_ptr, "libpng info structure allocation failed!") } - auto accessor = data.accessor(); - auto datap = accessor.data(); - auto datap_len = accessor.size(0); + auto datap = data.const_data_ptr(); + auto datap_len = data.numel(); if (setjmp(png_jmpbuf(png_ptr)) != 0) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "Internal error."); + VISION_CHECK(false, "Internal error."); } - TORCH_CHECK(datap_len >= 8, "Content is too small for png!") + VISION_CHECK(datap_len >= 8, "Content is too small for png!") auto is_png = !png_sig_cmp(datap, 0, 8); - TORCH_CHECK(is_png, "Content is not png!") + VISION_CHECK(is_png, "Content is not png!") struct Reader { png_const_bytep ptr; @@ -64,7 +62,7 @@ torch::Tensor decode_png( png_bytep output, png_size_t bytes) { auto reader = static_cast(png_get_io_ptr(png_ptr)); - TORCH_CHECK( + VISION_CHECK( reader->count >= bytes, "Out of bound read in decode_png. 
Probably, the input image is corrupted"); std::copy(reader->ptr, reader->ptr + bytes, output); @@ -91,12 +89,12 @@ torch::Tensor decode_png( if (retval != 1) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(retval == 1, "Could read image metadata from content.") + VISION_CHECK(retval == 1, "Could read image metadata from content.") } if (bit_depth > 8 && bit_depth != 16) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK( + VISION_CHECK( false, "bit depth of png image is " + std::to_string(bit_depth) + ". Only <=8 and 16 are supported.") @@ -188,7 +186,7 @@ torch::Tensor decode_png( break; default: png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "The provided mode is not supported for PNG files"); + VISION_CHECK(false, "The provided mode is not supported for PNG files"); } png_read_update_info(png_ptr, info_ptr); @@ -196,19 +194,25 @@ torch::Tensor decode_png( auto num_pixels_per_row = width * channels; auto is_16_bits = bit_depth == 16; - auto tensor = torch::empty( + auto tensor = emptyCPU( {int64_t(height), int64_t(width), channels}, - is_16_bits ? at::kUInt16 : torch::kU8); + is_16_bits ? kUInt16 : kByte); if (is_little_endian()) { png_set_swap(png_ptr); } - auto t_ptr = (uint8_t*)tensor.data_ptr(); + // Get raw pointer - for 16-bit images we get uint16_t* and cast to uint8_t* + auto t_ptr = is_16_bits + ? (uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); for (int pass = 0; pass < number_of_passes; pass++) { for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, t_ptr, nullptr); t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); } - t_ptr = (uint8_t*)tensor.data_ptr(); + // Reset pointer - for 16-bit images we get uint16_t* and cast to uint8_t* + t_ptr = is_16_bits + ? 
(uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); } int exif_orientation = -1; @@ -218,7 +222,7 @@ torch::Tensor decode_png( png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - auto output = tensor.permute({2, 0, 1}); + auto output = permute(tensor, {2, 0, 1}); if (apply_exif_orientation) { return exif_orientation_transform(output, exif_orientation); } diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h index faaffa7ae49..286cdf7571b 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.h +++ b/torchvision/csrc/io/image/cpu/decode_png.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_png( - const torch::Tensor& data, +vision::stable::Tensor decode_png( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index 80fe68862fb..945fffeba70 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -9,35 +9,30 @@ namespace vision { namespace image { +using namespace vision::stable; + #if !WEBP_FOUND -torch::Tensor decode_webp( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - TORCH_CHECK( +Tensor decode_webp(const Tensor& encoded_data, ImageReadMode mode) { + VISION_CHECK( false, "decode_webp: torchvision not compiled with libwebp support"); } #else -torch::Tensor decode_webp( - const torch::Tensor& encoded_data, - ImageReadMode mode) { +Tensor decode_webp(const Tensor& encoded_data, ImageReadMode mode) { validate_encoded_data(encoded_data); - auto encoded_data_p = encoded_data.data_ptr(); + auto encoded_data_p = encoded_data.const_data_ptr(); auto encoded_data_size = encoded_data.numel(); WebPBitstreamFeatures features; auto res = 
WebPGetFeatures(encoded_data_p, encoded_data_size, &features); - TORCH_CHECK( + VISION_CHECK( res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res); - TORCH_CHECK( + VISION_CHECK( !features.has_animation, "Animated webp files are not supported."); - if (mode == IMAGE_READ_MODE_GRAY || mode == IMAGE_READ_MODE_GRAY_ALPHA) { - TORCH_WARN_ONCE( - "Webp does not support grayscale conversions. " - "The returned tensor will be in the colorspace of the original image."); - } + // Note: TORCH_WARN_ONCE is not available in stable ABI, so we just skip the + // warning auto return_rgb = should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( @@ -52,13 +47,18 @@ torch::Tensor decode_webp( auto decoded_data = decoding_func(encoded_data_p, encoded_data_size, &width, &height); - TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + VISION_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + + // Create tensor and copy data (from_blob with deleter not available in stable + // ABI) + auto out = emptyCPU({height, width, num_channels}, kByte); + auto out_ptr = out.mutable_data_ptr(); + std::memcpy(out_ptr, decoded_data, height * width * num_channels); - auto deleter = [decoded_data](void*) { WebPFree(decoded_data); }; - auto out = torch::from_blob( - decoded_data, {height, width, num_channels}, deleter, torch::kUInt8); + // Free the webp-allocated memory + WebPFree(decoded_data); - return out.permute({2, 0, 1}); + return permute(out, {2, 0, 1}); } #endif // WEBP_FOUND diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index d5c81547c42..8e897ff8d8b 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_webp( - const 
torch::Tensor& encoded_data, +vision::stable::Tensor decode_webp( + const vision::stable::Tensor& encoded_data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp index d2ed73071a2..71b1749f5b2 100644 --- a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp @@ -1,14 +1,17 @@ #include "encode_jpeg.h" +#include "../common.h" #include "common_jpeg.h" namespace vision { namespace image { +using namespace vision::stable; + #if !JPEG_FOUND -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - TORCH_CHECK( +Tensor encode_jpeg(const Tensor& data, int64_t quality) { + VISION_CHECK( false, "encode_jpeg: torchvision not compiled with libjpeg support"); } @@ -24,10 +27,9 @@ using JpegSizeType = size_t; using namespace detail; -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); - // Define compression structures and error handling +Tensor encode_jpeg(const Tensor& data, int64_t quality) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI, so we just + // skip the logging Define compression structures and error handling struct jpeg_compress_struct cinfo {}; struct torch_jpeg_error_mgr jerr {}; @@ -48,25 +50,26 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { free(jpegBuf); } - TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); + VISION_CHECK(false, (const char*)jerr.jpegLastErrorMsg); } // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device() == Device(kCPU), "Input tensor should be on CPU"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + 
data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + VISION_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); // Get image info int channels = data.size(0); int height = data.size(1); int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); + auto input = torch::stable::contiguous(permute(data, {1, 2, 0})); - TORCH_CHECK( + VISION_CHECK( channels == 1 || channels == 3, "The number of channels should be 1 or 3, got: ", channels); @@ -90,21 +93,24 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { jpeg_start_compress(&cinfo, TRUE); auto stride = width * channels; - auto ptr = input.data_ptr(); + auto ptr = input.const_data_ptr(); // Encode JPEG file while (cinfo.next_scanline < cinfo.image_height) { - jpeg_write_scanlines(&cinfo, &ptr, 1); + jpeg_write_scanlines(&cinfo, const_cast(&ptr), 1); ptr += stride; } jpeg_finish_compress(&cinfo); jpeg_destroy_compress(&cinfo); - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto out_tensor = - torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); - jpegBuf = nullptr; + // Create tensor and copy data (from_blob with deleter not available in stable + // ABI) + auto out_tensor = emptyCPU({(int64_t)jpegSize}, kByte); + auto out_ptr = out_tensor.mutable_data_ptr(); + std::memcpy(out_ptr, jpegBuf, jpegSize); + free(jpegBuf); + return out_tensor; } #endif diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.h b/torchvision/csrc/io/image/cpu/encode_jpeg.h index 25084e154d6..2b69cc4e319 100644 --- a/torchvision/csrc/io/image/cpu/encode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/encode_jpeg.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor encode_jpeg( - const torch::Tensor& data, +vision::stable::Tensor 
encode_jpeg( + const vision::stable::Tensor& data, int64_t quality); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/encode_png.cpp b/torchvision/csrc/io/image/cpu/encode_png.cpp index d55a0ed3ff6..8a9671b5b7b 100644 --- a/torchvision/csrc/io/image/cpu/encode_png.cpp +++ b/torchvision/csrc/io/image/cpu/encode_png.cpp @@ -1,14 +1,17 @@ -#include "encode_jpeg.h" +#include "encode_png.h" +#include "../common.h" #include "common_png.h" namespace vision { namespace image { +using namespace vision::stable; + #if !PNG_FOUND -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - TORCH_CHECK( +Tensor encode_png(const Tensor& data, int64_t compression_level) { + VISION_CHECK( false, "encode_png: torchvision not compiled with libpng support"); } @@ -64,9 +67,9 @@ void torch_png_write_data( } // namespace -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); - // Define compression structures and error handling +Tensor encode_png(const Tensor& data, int64_t compression_level) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI, so we just + // skip the logging Define compression structures and error handling png_structp png_write; png_infop info_ptr; struct torch_png_error_mgr err_ptr; @@ -93,30 +96,31 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { free(buf_info.buffer); } - TORCH_CHECK(false, err_ptr.pngLastErrorMsg); + VISION_CHECK(false, err_ptr.pngLastErrorMsg); } // Check that the compression level is between 0 and 9 - TORCH_CHECK( + VISION_CHECK( compression_level >= 0 && compression_level <= 9, "Compression level should be between 0 and 9"); // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device() == Device(kCPU), "Input tensor should be on CPU"); // Check that the input tensor 
dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + VISION_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); // Get image info int channels = data.size(0); int height = data.size(1); int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); + auto input = torch::stable::contiguous(permute(data, {1, 2, 0})); - TORCH_CHECK( + VISION_CHECK( channels == 1 || channels == 3, "The number of channels should be 1 or 3, got: ", channels); @@ -150,7 +154,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { png_write_info(png_write, info_ptr); auto stride = width * channels; - auto ptr = input.data_ptr(); + auto ptr = input.const_data_ptr(); // Encode PNG file for (int y = 0; y < height; ++y) { @@ -164,13 +168,10 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { // Destroy structures png_destroy_write_struct(&png_write, &info_ptr); - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto outTensor = torch::empty({(long)buf_info.size}, options); - - // Copy memory from png buffer, since torch cannot get ownership of it via - // `from_blob` - auto outPtr = outTensor.data_ptr(); - std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); + // Create tensor and copy data + auto outTensor = emptyCPU({(int64_t)buf_info.size}, kByte); + auto outPtr = outTensor.mutable_data_ptr(); + std::memcpy(outPtr, buf_info.buffer, buf_info.size); free(buf_info.buffer); return outTensor; diff --git a/torchvision/csrc/io/image/cpu/encode_png.h b/torchvision/csrc/io/image/cpu/encode_png.h index 86a67c8706e..a27b4d47b5a 100644 --- a/torchvision/csrc/io/image/cpu/encode_png.h +++ 
b/torchvision/csrc/io/image/cpu/encode_png.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor encode_png( - const torch::Tensor& data, +vision::stable::Tensor encode_png( + const vision::stable::Tensor& data, int64_t compression_level); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/exif.h b/torchvision/csrc/io/image/cpu/exif.h index 7680737f8c0..022340288fc 100644 --- a/torchvision/csrc/io/image/cpu/exif.h +++ b/torchvision/csrc/io/image/cpu/exif.h @@ -58,12 +58,14 @@ direct, #include #endif -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { namespace exif_private { +using namespace vision::stable; + constexpr uint16_t APP1 = 0xe1; constexpr uint16_t ENDIANNESS_INTEL = 0x49; constexpr uint16_t ENDIANNESS_MOTO = 0x4d; @@ -78,7 +80,7 @@ class ExifDataReader { return _size; } const unsigned char& operator[](size_t index) const { - TORCH_CHECK(index >= 0 && index < _size); + VISION_CHECK(index >= 0 && index < _size, "EXIF data index out of bounds"); return _ptr[index]; } @@ -227,27 +229,25 @@ constexpr uint16_t IMAGE_ORIENTATION_RB = 7; // mirrored horizontal & rotate 90 CW constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation -inline torch::Tensor exif_orientation_transform( - const torch::Tensor& image, - int orientation) { +inline Tensor exif_orientation_transform(const Tensor& image, int orientation) { if (orientation == IMAGE_ORIENTATION_TL) { return image; } else if (orientation == IMAGE_ORIENTATION_TR) { - return image.flip(-1); + return flip(image, {-1}); } else if (orientation == IMAGE_ORIENTATION_BR) { // needs 180 rotation equivalent to // flip both horizontally and vertically - return image.flip({-2, -1}); + return flip(image, {-2, -1}); } else if (orientation == IMAGE_ORIENTATION_BL) { - return image.flip(-2); + return flip(image, {-2}); } else if (orientation == IMAGE_ORIENTATION_LT) { - return 
image.transpose(-1, -2); + return torch::stable::transpose(image, -1, -2); } else if (orientation == IMAGE_ORIENTATION_RT) { - return image.transpose(-1, -2).flip(-1); + return flip(torch::stable::transpose(image, -1, -2), {-1}); } else if (orientation == IMAGE_ORIENTATION_RB) { - return image.transpose(-1, -2).flip({-2, -1}); + return flip(torch::stable::transpose(image, -1, -2), {-2, -1}); } else if (orientation == IMAGE_ORIENTATION_LB) { - return image.transpose(-1, -2).flip(-2); + return flip(torch::stable::transpose(image, -1, -2), {-2}); } return image; } diff --git a/torchvision/csrc/io/image/cpu/read_write_file.cpp b/torchvision/csrc/io/image/cpu/read_write_file.cpp index 06de72a5053..9c94304f6f0 100644 --- a/torchvision/csrc/io/image/cpu/read_write_file.cpp +++ b/torchvision/csrc/io/image/cpu/read_write_file.cpp @@ -10,6 +10,8 @@ namespace vision { namespace image { +using namespace vision::stable; + #ifdef _WIN32 namespace { std::wstring utf8_decode(const std::string& str) { @@ -18,7 +20,7 @@ std::wstring utf8_decode(const std::string& str) { } int size_needed = MultiByteToWideChar( CP_UTF8, 0, str.c_str(), static_cast(str.size()), nullptr, 0); - TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + VISION_CHECK(size_needed > 0, "Error converting the content to Unicode"); std::wstring wstrTo(size_needed, 0); MultiByteToWideChar( CP_UTF8, @@ -32,9 +34,7 @@ std::wstring utf8_decode(const std::string& str) { } // namespace #endif -torch::Tensor read_file(const std::string& filename) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.read_file"); +Tensor read_file(std::string filename) { #ifdef _WIN32 // According to // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, @@ -47,49 +47,44 @@ torch::Tensor read_file(const std::string& filename) { int rc = stat(filename.c_str(), &stat_buf); #endif // errno is a variable defined in errno.h - TORCH_CHECK( + VISION_CHECK( rc == 
0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); int64_t size = stat_buf.st_size; - TORCH_CHECK(size > 0, "Expected a non empty file"); + VISION_CHECK(size > 0, "Expected a non empty file"); #ifdef _WIN32 - // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move - // back to use the following implementation since it uses file mapping. - // auto data = - // torch::from_file(filename, /*shared=*/false, /*size=*/size, - // torch::kU8).clone() + // On Windows, read file manually since torch::from_file may not handle UTF-8 FILE* infile = _wfopen(fileW.c_str(), L"rb"); +#else + // Use fopen/fread instead of from_file for stable ABI compatibility + FILE* infile = fopen(filename.c_str(), "rb"); +#endif - TORCH_CHECK(infile != nullptr, "Error opening input file"); + VISION_CHECK(infile != nullptr, "Error opening input file"); - auto data = torch::empty({size}, torch::kU8); - auto dataBytes = data.data_ptr(); + auto data = emptyCPU({size}, kByte); + auto dataBytes = data.mutable_data_ptr(); fread(dataBytes, sizeof(uint8_t), size, infile); fclose(infile); -#else - auto data = - torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); -#endif return data; } -void write_file(const std::string& filename, torch::Tensor& data) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.write_file"); +void write_file(std::string filename, const Tensor& data) { // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device().type() == kCPU, "Input tensor should be on CPU"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); + 
VISION_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); - auto fileBytes = data.data_ptr(); + auto fileBytes = data.const_data_ptr(); auto fileCStr = filename.c_str(); #ifdef _WIN32 auto fileW = utf8_decode(filename); @@ -98,7 +93,7 @@ void write_file(const std::string& filename, torch::Tensor& data) { FILE* outfile = fopen(fileCStr, "wb"); #endif - TORCH_CHECK(outfile != nullptr, "Error opening output file"); + VISION_CHECK(outfile != nullptr, "Error opening output file"); fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); fclose(outfile); diff --git a/torchvision/csrc/io/image/cpu/read_write_file.h b/torchvision/csrc/io/image/cpu/read_write_file.h index a5a712dd8e2..2170c71bdce 100644 --- a/torchvision/csrc/io/image/cpu/read_write_file.h +++ b/torchvision/csrc/io/image/cpu/read_write_file.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor read_file(const std::string& filename); +vision::stable::Tensor read_file(std::string filename); -C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); +void write_file(std::string filename, const vision::stable::Tensor& data); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 85aa6c760c1..6159a2ac9af 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -2,20 +2,20 @@ #if !NVJPEG_FOUND namespace vision { namespace image { -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, + +using namespace vision::stable; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device) { - TORCH_CHECK( + Device device) { + VISION_CHECK( false, "decode_jpegs_cuda: torchvision not compiled with nvJPEG support"); } } // namespace 
image } // namespace vision #else -#include -#include -#include #include #include #include @@ -28,40 +28,41 @@ std::vector decode_jpegs_cuda( namespace vision { namespace image { +using namespace vision::stable; + std::mutex decoderMutex; std::unique_ptr cudaJpegDecoder; -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + Device device) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI std::lock_guard lock(decoderMutex); - std::vector contig_images; + std::vector contig_images; contig_images.reserve(encoded_images.size()); - TORCH_CHECK( + VISION_CHECK( device.is_cuda(), "Expected the device parameter to be a cuda device"); for (auto& encoded_image : encoded_images) { - TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + VISION_CHECK( + encoded_image.scalar_type() == kByte, "Expected a torch.uint8 tensor"); - TORCH_CHECK( + VISION_CHECK( !encoded_image.is_cuda(), "The input tensor must be on CPU when decoding with nvjpeg") - TORCH_CHECK( + VISION_CHECK( encoded_image.dim() == 1 && encoded_image.numel() > 0, "Expected a non empty 1-dimensional tensor"); // nvjpeg requires images to be contiguous - if (encoded_image.is_contiguous()) { + if (is_contiguous(encoded_image)) { contig_images.push_back(encoded_image); } else { - contig_images.push_back(encoded_image.contiguous()); + contig_images.push_back(torch::stable::contiguous(encoded_image)); } } @@ -72,27 +73,31 @@ std::vector decode_jpegs_cuda( nvjpegStatus_t get_minor_property_status = nvjpegGetProperty(MINOR_VERSION, &minor_version); - TORCH_CHECK( + VISION_CHECK( get_major_property_status == NVJPEG_STATUS_SUCCESS, "nvjpegGetProperty failed: ", get_major_property_status); - TORCH_CHECK( + VISION_CHECK( 
get_minor_property_status == NVJPEG_STATUS_SUCCESS, "nvjpegGetProperty failed: ", get_minor_property_status); - if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { - TORCH_WARN_ONCE( - "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " - "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); - } + // Note: TORCH_WARN_ONCE is not available in stable ABI, so we skip the + // warning about nvjpeg memory leaks in CUDA versions < 11.6 + + // Set the target CUDA device + int prev_device; + cudaGetDevice(&prev_device); + int target_device_idx = device.has_index() ? device.index() : prev_device; + cudaSetDevice(target_device_idx); - at::cuda::CUDAGuard device_guard(device); + // Create a new device with the resolved index for consistency + Device resolved_device(kCUDA, static_cast(target_device_idx)); - if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { + if (cudaJpegDecoder == nullptr || resolved_device != cudaJpegDecoder->target_device) { if (cudaJpegDecoder != nullptr) { - cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); + cudaJpegDecoder.reset(new CUDAJpegDecoder(resolved_device)); } else { - cudaJpegDecoder = std::make_unique(device); + cudaJpegDecoder = std::make_unique(resolved_device); std::atexit([]() { cudaJpegDecoder.reset(); }); } } @@ -114,36 +119,41 @@ std::vector decode_jpegs_cuda( output_format = NVJPEG_OUTPUT_RGB; break; default: - TORCH_CHECK( + VISION_CHECK( false, "The provided mode is not supported for JPEG decoding on GPU"); } try { - at::cuda::CUDAEvent event; auto result = cudaJpegDecoder->decode_images(contig_images, output_format); - auto current_stream{ - device.has_index() ? 
at::cuda::getCurrentCUDAStream( - cudaJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(cudaJpegDecoder->stream); - event.block(current_stream); + + // Synchronize the decoder stream with the current stream using events + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, cudaJpegDecoder->stream); + // Use the default stream for synchronization + cudaStreamWaitEvent(nullptr, event, 0); + cudaEventDestroy(event); + + // Restore original device + cudaSetDevice(prev_device); return result; } catch (const std::exception& e) { - if (typeid(e) != typeid(std::runtime_error)) { - TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } else { - throw; - } + cudaSetDevice(prev_device); + throw std::runtime_error(std::string("Error while decoding JPEG images: ") + e.what()); } } -CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, c10::cuda::current_device()}, +CUDAJpegDecoder::CUDAJpegDecoder(const Device& target_device) + : original_device{kCUDA, []() { + int dev; + cudaGetDevice(&dev); + return static_cast(dev); + }()}, target_device{target_device}, - stream{ - target_device.has_index() - ? 
at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)} { + stream{nullptr} { + // Create a CUDA stream for this decoder + cudaStreamCreate(&stream); + nvjpegStatus_t status; hw_decode_available = true; @@ -160,70 +170,70 @@ CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) NULL, NVJPEG_FLAGS_DEFAULT, &nvjpeg_handle); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize nvjpeg with default backend: ", status); hw_decode_available = false; } else { - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize nvjpeg with hardware backend: ", status); } status = nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg state: ", status); status = nvjpegDecoderCreate( nvjpeg_handle, NVJPEG_BACKEND_DEFAULT, &nvjpeg_decoder); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg decoder: ", status); status = nvjpegDecoderStateCreate( nvjpeg_handle, nvjpeg_decoder, &nvjpeg_decoupled_state); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg decoder state: ", status); status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[0]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create pinned buffer: ", status); status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[1]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create pinned buffer: ", status); status = nvjpegBufferDeviceCreate(nvjpeg_handle, NULL, &device_buffer); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create device buffer: ", status); status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[0]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create jpeg stream: ", status); status = 
nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[1]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create jpeg stream: ", status); status = nvjpegDecodeParamsCreate(nvjpeg_handle, &nvjpeg_decode_params); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create decode params: ", status); @@ -240,80 +250,21 @@ CUDAJpegDecoder::~CUDAJpegDecoder() { Please send a PR if you have a solution for this problem. */ - // nvjpegStatus_t status; - - // status = nvjpegDecodeParamsDestroy(nvjpeg_decode_params); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decode params: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[0]: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[1]: ", - // status); - - // status = nvjpegBufferDeviceDestroy(device_buffer); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy device buffer: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_decoupled_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoupled state: ", - // status); - - // status = nvjpegDecoderDestroy(nvjpeg_decoder); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoder: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_state); - // TORCH_CHECK( - // status == 
NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg state: ", - // status); - - // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); + // if (stream != nullptr) { + // cudaStreamDestroy(stream); + // } } -std::tuple< - std::vector, - std::vector, - std::vector> +std::tuple, std::vector, std::vector> CUDAJpegDecoder::prepare_buffers( - const std::vector& encoded_images, + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format) { /* This function scans the encoded images' jpeg headers and allocates decoding buffers based on the metadata found Args: - - encoded_images (std::vector): a vector of tensors + - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded. Each tensor must have dtype torch.uint8 and device cpu - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y @@ -322,7 +273,7 @@ CUDAJpegDecoder::prepare_buffers( Returns: - decoded_images (std::vector): a vector of nvjpegImages containing pointers to the memory of the decoded images - - output_tensors (std::vector): a vector of Tensors + - output_tensors (std::vector): a vector of Tensors containing the decoded images. 
`decoded_images` points to the memory of output_tensors - channels (std::vector): a vector of ints containing the number of @@ -335,25 +286,24 @@ CUDAJpegDecoder::prepare_buffers( nvjpegChromaSubsampling_t subsampling; nvjpegStatus_t status; - std::vector output_tensors{encoded_images.size()}; + std::vector output_tensors{encoded_images.size()}; std::vector decoded_images{encoded_images.size()}; - for (std::vector::size_type i = 0; i < encoded_images.size(); - i++) { + for (size_t i = 0; i < encoded_images.size(); i++) { // extract bitstream meta data to figure out the number of channels, height, // width for every image status = nvjpegGetImageInfo( nvjpeg_handle, - (unsigned char*)encoded_images[i].data_ptr(), + encoded_images[i].const_data_ptr(), encoded_images[i].numel(), &channels[i], &subsampling, width, height); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to get image info: ", status); - TORCH_CHECK( + VISION_CHECK( subsampling != NVJPEG_CSS_UNKNOWN, "Unknown chroma subsampling"); // output channels may be different from the actual number of channels in @@ -368,14 +318,16 @@ CUDAJpegDecoder::prepare_buffers( } // reserve output buffer - auto output_tensor = torch::empty( + auto output_tensor = empty( {int64_t(output_channels), int64_t(height[0]), int64_t(width[0])}, - torch::dtype(torch::kU8).device(target_device)); + kByte, + target_device); output_tensors[i] = output_tensor; // fill nvjpegImage_t struct for (int c = 0; c < output_channels; c++) { - decoded_images[i].channel[c] = output_tensor[c].data_ptr(); + decoded_images[i].channel[c] = torch::stable::select(output_tensor, 0, c) + .mutable_data_ptr(); decoded_images[i].pitch[c] = width[0]; } for (int c = output_channels; c < NVJPEG_MAX_COMPONENT; c++) { @@ -386,8 +338,8 @@ CUDAJpegDecoder::prepare_buffers( return {decoded_images, output_tensors, channels}; } -std::vector CUDAJpegDecoder::decode_images( - const std::vector& encoded_images, +std::vector 
CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format) { /* This function decodes a batch of jpeg bitstreams. @@ -402,14 +354,14 @@ std::vector CUDAJpegDecoder::decode_images( for reference. Args: - - encoded_images (std::vector): a vector of tensors + - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y or NVJPEG_OUTPUT_UNCHANGED - - device (torch::Device): The desired CUDA device for the returned Tensors + - device (Device): The desired CUDA device for the returned Tensors Returns: - - output_tensors (std::vector): a vector of Tensors + - output_tensors (std::vector): a vector of Tensors containing the decoded images */ @@ -420,7 +372,7 @@ std::vector CUDAJpegDecoder::decode_images( cudaError_t cudaStatus; cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); @@ -438,13 +390,12 @@ std::vector CUDAJpegDecoder::decode_images( std::vector sw_output_buffer; if (hw_decode_available) { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { + for (size_t i = 0; i < encoded_images.size(); ++i) { // extract bitstream meta data to figure out whether a bit-stream can be // decoded nvjpegJpegStreamParseHeader( nvjpeg_handle, - encoded_images[i].data_ptr(), + encoded_images[i].const_data_ptr(), encoded_images[i].numel(), jpeg_streams[0]); int isSupported = -1; @@ -452,19 +403,18 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_handle, jpeg_streams[0], &isSupported); if (isSupported == 0) { - hw_input_buffer.push_back(encoded_images[i].data_ptr()); + hw_input_buffer.push_back(encoded_images[i].const_data_ptr()); hw_input_buffer_size.push_back(encoded_images[i].numel()); hw_output_buffer.push_back(decoded_imgs_buf[i]); } else { - 
sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer.push_back(encoded_images[i].const_data_ptr()); sw_input_buffer_size.push_back(encoded_images[i].numel()); sw_output_buffer.push_back(decoded_imgs_buf[i]); } } } else { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { - sw_input_buffer.push_back(encoded_images[i].data_ptr()); + for (size_t i = 0; i < encoded_images.size(); ++i) { + sw_input_buffer.push_back(encoded_images[i].const_data_ptr()); sw_input_buffer_size.push_back(encoded_images[i].numel()); sw_output_buffer.push_back(decoded_imgs_buf[i]); } @@ -479,7 +429,7 @@ std::vector CUDAJpegDecoder::decode_images( 1, output_format == NVJPEG_OUTPUT_UNCHANGED ? NVJPEG_OUTPUT_RGB : output_format); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize batch decoding: ", status); @@ -491,14 +441,14 @@ std::vector CUDAJpegDecoder::decode_images( hw_input_buffer_size.data(), hw_output_buffer.data(), stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode batch: ", status); } if (sw_input_buffer.size() > 0) { status = nvjpegStateAttachDeviceBuffer(nvjpeg_decoupled_state, device_buffer); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to attach device buffer: ", status); @@ -508,12 +458,11 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decode_params, output_format == NVJPEG_OUTPUT_UNCHANGED ? 
NVJPEG_OUTPUT_RGB : output_format); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to set output format: ", status); - for (std::vector::size_type i = 0; i < sw_input_buffer.size(); - ++i) { + for (size_t i = 0; i < sw_input_buffer.size(); ++i) { status = nvjpegJpegStreamParse( nvjpeg_handle, sw_input_buffer[i], @@ -521,14 +470,14 @@ std::vector CUDAJpegDecoder::decode_images( 0, 0, jpeg_streams[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to parse jpeg stream: ", status); status = nvjpegStateAttachPinnedBuffer( nvjpeg_decoupled_state, pinned_buffers[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to attach pinned buffer: ", status); @@ -539,13 +488,13 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, nvjpeg_decode_params, jpeg_streams[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode jpeg stream: ", status); cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); @@ -556,7 +505,7 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, jpeg_streams[buffer_index], stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to transfer jpeg to device: ", status); @@ -570,7 +519,7 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, &sw_output_buffer[i], stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode jpeg stream: ", status); @@ -578,17 +527,17 @@ std::vector CUDAJpegDecoder::decode_images( } cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); // prune extraneous channels from single channel images if (output_format == NVJPEG_OUTPUT_UNCHANGED) { - for (std::vector::size_type i = 0; i < 
output_tensors.size(); - ++i) { + for (size_t i = 0; i < output_tensors.size(); ++i) { if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + output_tensors[i] = torch::stable::clone(torch::stable::unsqueeze( + torch::stable::select(output_tensors[i], 0, 0), 0)); } } } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..3274e119279 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,34 +1,34 @@ #pragma once -#include #include +#include "../../../StableABICompat.h" #include "../common.h" #if NVJPEG_FOUND -#include +#include #include namespace vision { namespace image { class CUDAJpegDecoder { public: - CUDAJpegDecoder(const torch::Device& target_device); + CUDAJpegDecoder(const vision::stable::Device& target_device); ~CUDAJpegDecoder(); - std::vector decode_images( - const std::vector& encoded_images, + std::vector decode_images( + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format); - const torch::Device original_device; - const torch::Device target_device; - const c10::cuda::CUDAStream stream; + const vision::stable::Device original_device; + const vision::stable::Device target_device; + cudaStream_t stream; private: std::tuple< std::vector, - std::vector, + std::vector, std::vector> prepare_buffers( - const std::vector& encoded_images, + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format); nvjpegJpegState_t nvjpeg_state; nvjpegJpegState_t nvjpeg_decoupled_state; diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 8c3ad8f9a9d..9074e9464f9 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include 
"../../../StableABICompat.h" #include "../common.h" #include "decode_jpegs_cuda.h" #include "encode_jpegs_cuda.h" @@ -14,45 +14,45 @@ Fast jpeg decoding with CUDA. A100+ GPUs have dedicated hardware support for jpeg decoding. Args: - - encoded_images (const std::vector&): a vector of tensors - containing the jpeg bitstreams to be decoded. Each tensor must have dtype - torch.uint8 and device cpu + - encoded_images (const std::vector&): a vector of +tensors containing the jpeg bitstreams to be decoded. Each tensor must have +dtype torch.uint8 and device cpu - mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and IMAGE_READ_MODE_RGB are supported - - device (torch::Device): The desired CUDA device to run the decoding on and -which will contain the output tensors + - device (vision::stable::Device): The desired CUDA device to run the +decoding on and which will contain the output tensors Returns: - - decoded_images (std::vector): a vector of torch::Tensors of -dtype torch.uint8 on the specified containing the decoded images + - decoded_images (std::vector): a vector of Tensors +of dtype torch.uint8 on the specified containing the decoded images Notes: - If a single image fails, the whole batch fails. - This function is thread-safe */ -C10_EXPORT std::vector decode_jpegs_cuda( - const std::vector& encoded_images, +C10_EXPORT std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device); + vision::stable::Device device); /* Fast jpeg encoding with CUDA. Args: - - decoded_images (const std::vector&): a vector of contiguous -CUDA tensors of dtype torch.uint8 to be encoded. + - decoded_images (const std::vector&): a vector of +contiguous CUDA tensors of dtype torch.uint8 to be encoded. 
- quality (int64_t): 0-100, 75 is the default Returns: - - encoded_images (std::vector): a vector of CUDA -torch::Tensors of dtype torch.uint8 containing the encoded images + - encoded_images (std::vector): a vector of CUDA +Tensors of dtype torch.uint8 containing the encoded images Notes: - If a single image fails, the whole batch fails. - This function is thread-safe */ -C10_EXPORT std::vector encode_jpegs_cuda( - const std::vector& decoded_images, +C10_EXPORT std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality); } // namespace image diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index 80accc1a241..aa582701b15 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -2,55 +2,67 @@ #if !NVJPEG_FOUND namespace vision { namespace image { -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, + +using namespace vision::stable; + +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality) { - TORCH_CHECK( + VISION_CHECK( false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); } } // namespace image } // namespace vision #else -#include -#include +#include #include #include +#include #include -#include "c10/core/ScalarType.h" namespace vision { namespace image { +using namespace vision::stable; + // We use global variables to cache the encoder and decoder instances and // reuse them across calls to the corresponding pytorch functions std::mutex encoderMutex; std::unique_ptr cudaJpegEncoder; -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI // Some nvjpeg 
structures are not thread safe so we're keeping it single // threaded for now. In the future this may be an opportunity to unlock // further speedups std::lock_guard lock(encoderMutex); - TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); - torch::Device device = decoded_images[0].device(); - at::cuda::CUDAGuard device_guard(device); + VISION_CHECK(decoded_images.size() > 0, "Empty input tensor list"); + Device device = Device( + decoded_images[0].device().type(), decoded_images[0].device().index()); + + // Set the target CUDA device + int prev_device; + cudaGetDevice(&prev_device); + int target_device_idx = device.has_index() ? device.index() : prev_device; + cudaSetDevice(target_device_idx); + + // Create a device with the resolved index for consistency + Device resolved_device(kCUDA, static_cast(target_device_idx)); // lazy init of the encoder class // the encoder object holds on to a lot of state and is expensive to create, // so we reuse it across calls. NB: the cached structures are device specific // and cannot be reused across devices - if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder == nullptr || resolved_device != cudaJpegEncoder->target_device) { if (cudaJpegEncoder != nullptr) { delete cudaJpegEncoder.release(); } - cudaJpegEncoder = std::make_unique(device); + cudaJpegEncoder = std::make_unique(resolved_device); // Unfortunately, we cannot rely on the smart pointer releasing the encoder // object correctly upon program exit. 
This is because, when cudaJpegEncoder @@ -63,41 +75,40 @@ std::vector encode_jpegs_cuda( std::atexit([]() { delete cudaJpegEncoder.release(); }); } - std::vector contig_images; + std::vector contig_images; contig_images.reserve(decoded_images.size()); for (const auto& image : decoded_images) { - TORCH_CHECK( - image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + image.scalar_type() == kByte, "Input tensor dtype should be uint8"); - TORCH_CHECK( - image.device() == device, + VISION_CHECK( + image.device().type() == device.type() && + image.device().index() == device.index(), "All input tensors must be on the same CUDA device when encoding with nvjpeg") - TORCH_CHECK( + VISION_CHECK( image.dim() == 3 && image.numel() > 0, "Input data should be a 3-dimensional tensor"); - TORCH_CHECK( + VISION_CHECK( image.size(0) == 3, "The number of channels should be 3, got: ", image.size(0)); // nvjpeg requires images to be contiguous - if (image.is_contiguous()) { + if (is_contiguous(image)) { contig_images.push_back(image); } else { - contig_images.push_back(image.contiguous()); + contig_images.push_back(torch::stable::contiguous(image)); } } cudaJpegEncoder->set_quality(quality); - std::vector encoded_images; + std::vector encoded_images; for (const auto& image : contig_images) { auto encoded_image = cudaJpegEncoder->encode_jpeg(image); encoded_images.push_back(encoded_image); } - at::cuda::CUDAEvent event; - event.record(cudaJpegEncoder->stream); // We use a dedicated stream to do the encoding and even though the results // may be ready on that stream we cannot assume that they are also available @@ -106,36 +117,46 @@ std::vector encode_jpegs_cuda( // do not want to block the host at this particular point // (which is what cudaStreamSynchronize would do.) Events allow us to // synchronize the streams without blocking the host. 
- event.block(cudaJpegEncoder->current_stream); + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, cudaJpegEncoder->stream); + cudaStreamWaitEvent(cudaJpegEncoder->current_stream, event, 0); + cudaEventDestroy(event); + + // Restore original device + cudaSetDevice(prev_device); return encoded_images; } -CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, +CUDAJpegEncoder::CUDAJpegEncoder(const Device& target_device) + : original_device{kCUDA, []() { + int dev; + cudaGetDevice(&dev); + return static_cast(dev); + }()}, target_device{target_device}, - stream{ - target_device.has_index() - ? at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)}, - current_stream{ - original_device.has_index() - ? at::cuda::getCurrentCUDAStream(original_device.index()) - : at::cuda::getCurrentCUDAStream()} { + stream{nullptr}, + current_stream{nullptr} { + // Create CUDA streams + cudaStreamCreate(&stream); + // Get the default stream (nullptr represents the default stream) + current_stream = nullptr; + nvjpegStatus_t status; status = nvjpegCreateSimple(&nvjpeg_handle); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg handle: ", status); status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg encoder state: ", status); status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg encoder params: ", status); @@ -152,45 +173,18 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { Please send a PR if you have a solution for this problem. */ - // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is - // still - // // initialized. 
If it is not, we can skip the rest of this function as it - // is - // // unsafe to execute. - // int deviceCount = 0; - // cudaError_t error = cudaGetDeviceCount(&deviceCount); - // if (error != cudaSuccess) - // return; // CUDA runtime has already shut down. There's nothing we can do - // // now. - - // nvjpegStatus_t status; - - // status = nvjpegEncoderParamsDestroy(nv_enc_params); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder params: ", - // status); - - // status = nvjpegEncoderStateDestroy(nv_enc_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder state: ", - // status); - - // cudaStreamSynchronize(stream); - - // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); + // if (stream != nullptr) { + // cudaStreamDestroy(stream); + // } } -torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { +Tensor CUDAJpegEncoder::encode_jpeg(const Tensor& src_image) { nvjpegStatus_t status; cudaError_t cudaStatus; // Ensure that the incoming src_image is safe to use cudaStatus = cudaStreamSynchronize(current_stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); int channels = src_image.size(0); int height = src_image.size(1); @@ -198,14 +192,15 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { status = nvjpegEncoderParamsSetSamplingFactors( nv_enc_params, NVJPEG_CSS_444, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to set nvjpeg encoder params sampling factors: ", status); nvjpegImage_t target_image; for (int c = 0; c < channels; c++) { - target_image.channel[c] = src_image[c].data_ptr(); + target_image.channel[c] = + torch::stable::select(src_image, 0, c).mutable_data_ptr(); // this is why we need contiguous tensors 
target_image.pitch[c] = width; } @@ -224,39 +219,34 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { height, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); // Retrieve length of the encoded image size_t length; status = nvjpegEncodeRetrieveBitstreamDevice( nvjpeg_handle, nv_enc_state, NULL, &length, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to retrieve encoded image stream state: ", status); // Synchronize the stream to ensure that the encoded image is ready cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); // Reserve buffer for the encoded image - torch::Tensor encoded_image = torch::empty( - {static_cast(length)}, - torch::TensorOptions() - .dtype(torch::kByte) - .layout(torch::kStrided) - .device(target_device) - .requires_grad(false)); + Tensor encoded_image = + empty({static_cast(length)}, kByte, target_device); cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); // Retrieve the encoded image status = nvjpegEncodeRetrieveBitstreamDevice( nvjpeg_handle, nv_enc_state, - encoded_image.data_ptr(), + encoded_image.mutable_data_ptr(), &length, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to retrieve encoded image: ", status); @@ -266,7 +256,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { void CUDAJpegEncoder::set_quality(const int64_t quality) { nvjpegStatus_t paramsQualityStatus = nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); - TORCH_CHECK( + VISION_CHECK( paramsQualityStatus == NVJPEG_STATUS_SUCCESS, "Failed to set nvjpeg encoder params quality: ", paramsQualityStatus); diff --git 
a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h index 6ee0ad91df4..020e1646340 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -1,10 +1,9 @@ #pragma once -#include #include +#include "../../../StableABICompat.h" #if NVJPEG_FOUND -#include -#include +#include #include namespace vision { @@ -12,17 +11,17 @@ namespace image { class CUDAJpegEncoder { public: - CUDAJpegEncoder(const torch::Device& device); + CUDAJpegEncoder(const vision::stable::Device& device); ~CUDAJpegEncoder(); - torch::Tensor encode_jpeg(const torch::Tensor& src_image); + vision::stable::Tensor encode_jpeg(const vision::stable::Tensor& src_image); void set_quality(const int64_t quality); - const torch::Device original_device; - const torch::Device target_device; - const c10::cuda::CUDAStream stream; - const c10::cuda::CUDAStream current_stream; + const vision::stable::Device original_device; + const vision::stable::Device target_device; + cudaStream_t stream; + cudaStream_t current_stream; protected: nvjpegEncoderState_t nv_enc_state; diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index b4a4ed54a67..ebfecabbe2d 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -1,29 +1,58 @@ #include "image.h" -#include +#include "../../StableABICompat.h" namespace vision { namespace image { -static auto registry = - torch::RegisterOperators() - .op("image::decode_gif", &decode_gif) - .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_png) - .op("image::encode_png", &encode_png) - .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_jpeg) - .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", - &decode_webp) - .op("image::encode_jpeg", &encode_jpeg) - .op("image::read_file", &read_file) 
- .op("image::write_file", &write_file) - .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_image) - .op("image::decode_jpegs_cuda", &decode_jpegs_cuda) - .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) - .op("image::_jpeg_version", &_jpeg_version) - .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); +// Register operators using stable ABI macros +STABLE_TORCH_LIBRARY(image, m) { + m.def("decode_gif(Tensor data) -> Tensor"); + m.def( + "decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + m.def("encode_png(Tensor data, int compression_level) -> Tensor"); + m.def( + "decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + m.def("decode_webp(Tensor encoded_data, int mode) -> Tensor"); + m.def("encode_jpeg(Tensor data, int quality) -> Tensor"); + m.def("read_file(str filename) -> Tensor"); + m.def("write_file(str filename, Tensor data) -> ()"); + m.def( + "decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + m.def( + "decode_jpegs_cuda(Tensor[] encoded_jpegs, int mode, Device device) -> Tensor[]"); + m.def("encode_jpegs_cuda(Tensor[] decoded_jpegs, int quality) -> Tensor[]"); + m.def("_jpeg_version() -> int"); + m.def("_is_compiled_against_turbo() -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) { + m.impl("decode_gif", TORCH_BOX(&decode_gif)); + m.impl("decode_png", TORCH_BOX(&decode_png)); + m.impl("decode_jpeg", TORCH_BOX(&decode_jpeg)); + m.impl("decode_webp", TORCH_BOX(&decode_webp)); + m.impl("decode_image", TORCH_BOX(&decode_image)); +} + +// Ops without tensor inputs or with cross-device semantics need BackendSelect dispatch +// encode_jpeg/encode_png also use BackendSelect so they can give proper error messages +// when CUDA tensors are passed (instead of "no kernel for CUDA") +STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { + m.impl("read_file", TORCH_BOX(&read_file)); + 
m.impl("write_file", TORCH_BOX(&write_file)); + m.impl("_jpeg_version", TORCH_BOX(&_jpeg_version)); + m.impl("_is_compiled_against_turbo", TORCH_BOX(&_is_compiled_against_turbo)); + // decode_jpegs_cuda takes CPU tensors as input but outputs CUDA tensors + m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); + // encode functions need BackendSelect to provide proper error messages for CUDA inputs + m.impl("encode_png", TORCH_BOX(&encode_png)); + m.impl("encode_jpeg", TORCH_BOX(&encode_jpeg)); +} + +STABLE_TORCH_LIBRARY_IMPL(image, CUDA, m) { + // encode_jpegs_cuda takes CUDA tensors as input + m.impl("encode_jpegs_cuda", TORCH_BOX(&encode_jpegs_cuda)); +} } // namespace image } // namespace vision diff --git a/torchvision/csrc/ops/autocast/nms_kernel.cpp b/torchvision/csrc/ops/autocast/nms_kernel.cpp deleted file mode 100644 index 39482ceadbf..00000000000 --- a/torchvision/csrc/ops/autocast/nms_kernel.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "../nms.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -at::Tensor nms_autocast( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); - - return nms( - at::autocast::cached_cast(at::kFloat, dets, device_type), - at::autocast::cached_cast(at::kFloat, scores, device_type), - iou_threshold); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::nms"), - TORCH_FN( - (nms_autocast))); -} - -TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::nms"), - TORCH_FN( - (nms_autocast))); -} - -TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::nms"), - TORCH_FN( - (nms_autocast))); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp 
b/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp deleted file mode 100644 index bce987b0f71..00000000000 --- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "../ps_roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_align_autocast( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto result = ps_roi_align( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, rois), - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); - - return std::make_tuple( - std::get<0>(result).to(input.scalar_type()), - std::get<1>(result).to(input.scalar_type())); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_autocast)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp deleted file mode 100644 index 3cf1e7f80d7..00000000000 --- a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "../ps_roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_pool_autocast( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto result = ps_roi_pool( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, rois), - spatial_scale, - pooled_height, - pooled_width); - - return std::make_tuple( - std::get<0>(result).to(input.scalar_type()), - 
std::get<1>(result).to(input.scalar_type())); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_autocast)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/roi_align_kernel.cpp deleted file mode 100644 index 3eb8443b54d..00000000000 --- a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "../roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -at::Tensor roi_align_autocast( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key); - return roi_align( - at::autocast::cached_cast(at::kFloat, input, device_type), - at::autocast::cached_cast(at::kFloat, rois, device_type), - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned) - .to(input.scalar_type()); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN((roi_align_autocast< - c10::DispatchKey::Autocast, - c10::DeviceType::CUDA>))); -} - -TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN((roi_align_autocast< - c10::DispatchKey::AutocastCPU, - c10::DeviceType::CPU>))); -} - -TORCH_LIBRARY_IMPL(torchvision, AutocastXPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN((roi_align_autocast< - c10::DispatchKey::AutocastXPU, - c10::DeviceType::XPU>))); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp deleted file mode 100644 index 
3aaa038a9b4..00000000000 --- a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "../roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -namespace { - -std::tuple roi_pool_autocast( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - auto result = roi_pool( - at::autocast::cached_cast(at::kFloat, input), - at::autocast::cached_cast(at::kFloat, rois), - spatial_scale, - pooled_height, - pooled_width); - - return std::make_tuple( - std::get<0>(result).to(input.scalar_type()), - std::get<1>(result).to(input.scalar_type())); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_autocast)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp deleted file mode 100644 index 01f7dd1aa76..00000000000 --- a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,169 +0,0 @@ -#include "../ps_roi_align.h" - -#include -#include - -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIAlignFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - const c10::SymInt& pooled_height, - const c10::SymInt& pooled_width, - int64_t sampling_ratio) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["input_shape"] = input.sym_sizes(); - 
at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_align_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class PSROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = 
detail::_ps_roi_align_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - std::move(pooled_height), - std::move(pooled_width), - sampling_ratio, - std::move(batch_size), - std::move(channels), - std::move(height), - std::move(width)); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); - } -}; - -std::tuple ps_roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - auto result = PSROIAlignFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIAlignBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp deleted file mode 100644 index 5c3315bb52a..00000000000 --- a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp +++ /dev/null @@ 
-1,154 +0,0 @@ -#include "../ps_roi_pool.h" - -#include -#include - -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - const c10::SymInt& pooled_height, - const c10::SymInt& pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class 
PSROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - std::move(pooled_height), - std::move(pooled_width), - std::move(batch_size), - std::move(channels), - std::move(height), - std::move(width)); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); - } -}; - -std::tuple ps_roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - auto result = PSROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIPoolBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_autograd)); 
- m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autograd/roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/roi_align_kernel.cpp deleted file mode 100644 index 0a1ae55b971..00000000000 --- a/torchvision/csrc/ops/autograd/roi_align_kernel.cpp +++ /dev/null @@ -1,169 +0,0 @@ -#include "../roi_align.h" - -#include -#include - -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIAlignFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - const c10::SymInt& pooled_height, - const c10::SymInt& pooled_width, - int64_t sampling_ratio, - bool aligned) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["aligned"] = aligned; - ctx->saved_data["input_shape"] = input.sym_sizes(); - ctx->save_for_backward({rois}); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_align_backward_symint( - grad_output[0], - rois, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - 
input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - ctx->saved_data["aligned"].toBool()); - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_roi_align_backward_symint( - grad, - rois, - spatial_scale, - std::move(pooled_height), - std::move(pooled_width), - std::move(batch_size), - std::move(channels), - std::move(height), - std::move(width), - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_align not supported"); - } -}; - -at::Tensor roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignFunction::apply( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned)[0]; -} - -at::Tensor roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt 
pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignBackwardFunction::apply( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp deleted file mode 100644 index 4944a731c6b..00000000000 --- a/torchvision/csrc/ops/autograd/roi_pool_kernel.cpp +++ /dev/null @@ -1,154 +0,0 @@ -#include "../roi_pool.h" - -#include -#include - -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - const c10::SymInt& pooled_height, - const c10::SymInt& pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto argmax = std::get<1>(result); - ctx->save_for_backward({rois, argmax}); - ctx->mark_non_differentiable({argmax}); - - return {output, argmax}; - } - - static torch::autograd::variable_list backward( - 
torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto argmax = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_pool_backward_symint( - grad_output[0], - rois, - argmax, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_roi_pool_backward_symint( - grad, - rois, - argmax, - spatial_scale, - std::move(pooled_height), - std::move(pooled_width), - std::move(batch_size), - std::move(channels), - std::move(height), - std::move(width)); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_pool not supported"); - } -}; - -std::tuple roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - 
c10::SymInt pooled_width) { - auto result = ROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return ROIPoolBackwardFunction::apply( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp index f89e6cc3030..4bdfe09e664 100644 --- a/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp @@ -88,8 +88,8 @@ scalar_t bilinear_interpolate( return 0; } - int h_low = floor(h); - int w_low = floor(w); + int h_low = std::floor(h); + int w_low = std::floor(w); int h_high = h_low + 1; int w_high = w_low + 1; @@ -389,8 +389,8 @@ scalar_t get_coordinate_weight( scalar_t y, scalar_t x, bool is_y_direction) { - int y_l = floor(y); - int x_l = floor(x); + int y_l = std::floor(y); + int x_l = std::floor(x); int y_h = y_l + 1; int x_h = x_l + 1; diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 454ce118a6d..4811f98967e 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -1,48 +1,57 @@ -#include -#include +#include "../../StableABICompat.h" +#include namespace 
vision { namespace ops { namespace { +using namespace vision::stable; + template -at::Tensor nms_kernel_impl( - const at::Tensor& dets, - const at::Tensor& scores, +Tensor nms_kernel_impl( + const Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); - TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); - TORCH_CHECK( + VISION_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + VISION_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); + VISION_CHECK( dets.scalar_type() == scores.scalar_type(), "dets should have the same type as scores"); if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); + return empty({0}, kLong, Device(kCPU)); } - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); + auto x1_t = torch::stable::contiguous(torch::stable::select(dets, 1, 0)); + auto y1_t = torch::stable::contiguous(torch::stable::select(dets, 1, 1)); + auto x2_t = torch::stable::contiguous(torch::stable::select(dets, 1, 2)); + auto y2_t = torch::stable::contiguous(torch::stable::select(dets, 1, 3)); - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); + // Compute areas: (x2 - x1) * (y2 - y1) + // Need to do this manually with data pointers + auto ndets = dets.size(0); + Tensor areas_t = empty({ndets}, dets.scalar_type(), Device(kCPU)); - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto x1_ptr = x1_t.const_data_ptr(); + auto y1_ptr = y1_t.const_data_ptr(); + auto x2_ptr = x2_t.const_data_ptr(); + auto y2_ptr = y2_t.const_data_ptr(); + auto areas_ptr = areas_t.mutable_data_ptr(); - auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); - at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - - auto suppressed = 
suppressed_t.data_ptr(); - auto keep = keep_t.data_ptr(); - auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); + for (int64_t i = 0; i < ndets; i++) { + areas_ptr[i] = (x2_ptr[i] - x1_ptr[i]) * (y2_ptr[i] - y1_ptr[i]); + } + + // Sort scores descending + auto [sorted_scores, order_t] = sort(scores, /*dim=*/0, /*descending=*/true); + + Tensor suppressed_t = zeros({ndets}, kByte, Device(kCPU)); + Tensor keep_t = zeros({ndets}, kLong, Device(kCPU)); + + auto suppressed = suppressed_t.mutable_data_ptr(); + auto keep = keep_t.mutable_data_ptr(); + auto order = order_t.const_data_ptr(); int64_t num_to_keep = 0; @@ -52,50 +61,50 @@ at::Tensor nms_kernel_impl( continue; } keep[num_to_keep++] = i; - auto ix1 = x1[i]; - auto iy1 = y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; + auto ix1 = x1_ptr[i]; + auto iy1 = y1_ptr[i]; + auto ix2 = x2_ptr[i]; + auto iy2 = y2_ptr[i]; + auto iarea = areas_ptr[i]; for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) { continue; } - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); + auto xx1 = std::max(ix1, x1_ptr[j]); + auto yy1 = std::max(iy1, y1_ptr[j]); + auto xx2 = std::min(ix2, x2_ptr[j]); + auto yy2 = std::min(iy2, y2_ptr[j]); auto w = std::max(static_cast(0), xx2 - xx1); auto h = std::max(static_cast(0), yy2 - yy1); auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); + auto ovr = inter / (iarea + areas_ptr[j] - inter); if (ovr > iou_threshold) { suppressed[j] = 1; } } } - return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); + return torch::stable::narrow(keep_t, /*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, +Tensor nms_kernel( + const 
Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK( + VISION_CHECK( dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( + VISION_CHECK( dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK( + VISION_CHECK( scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK( + VISION_CHECK( dets.size(0) == scores.size(0), "boxes and scores should have same number of elements in ", "dimension 0, got ", @@ -103,18 +112,23 @@ at::Tensor nms_kernel( " and ", scores.size(0)); - auto result = at::empty({0}, dets.options()); + Tensor result = empty({0}, kLong, Device(kCPU)); - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); - }); + auto dtype = dets.scalar_type(); + if (dtype == kFloat) { + result = nms_kernel_impl(dets, scores, iou_threshold); + } else if (dtype == kDouble) { + result = nms_kernel_impl(dets, scores, iou_threshold); + } else { + VISION_CHECK(false, "nms only supports float and double types"); + } return result; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("nms", TORCH_BOX(&nms_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp index 820029c73d5..277de0ab99b 100644 --- a/torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp @@ -1,11 +1,13 @@ -#include -#include +#include "../../StableABICompat.h" +#include namespace vision { namespace ops { namespace { +using namespace vision::stable; + template T bilinear_interpolate( const T* input, @@ -108,10 +110,10 @@ void ps_roi_align_forward_kernel_impl( // We use roi_bin_grid to sample the grid and mimic integral int 
roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : ceil(roi_height / pooled_height); + : std::ceil(roi_height / pooled_height); int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio - : ceil(roi_width / pooled_width); + : std::ceil(roi_width / pooled_width); const T count = roi_bin_grid_h * roi_bin_grid_w; const T* offset_input = @@ -256,9 +258,9 @@ void ps_roi_align_backward_kernel_impl( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 + : std::ceil(roi_height / pooled_height); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width); const T count = roi_bin_grid_h * roi_bin_grid_w; for (int iy = 0; iy < roi_bin_grid_h; iy++) { @@ -304,68 +306,90 @@ void ps_roi_align_backward_kernel_impl( } } -std::tuple ps_roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple ps_roi_align_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio) { // Check if input tensors are CPU tensors - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( + VISION_CHECK(input.is_cpu(), "input must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); int num_rois = rois.size(0); int channels = input.size(1); int height = 
input.size(2); int width = input.size(3); - TORCH_CHECK( + VISION_CHECK( channels % (pooled_height * pooled_width) == 0, "input channels must be a multiple of pooling height * pooling width"); int channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); + Tensor output = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCPU)); + Tensor channel_mapping = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + kInt, + Device(kCPU)); if (output.numel() == 0) { return std::make_tuple(output, channel_mapping); } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_kernel", [&] { - ps_roi_align_forward_kernel_impl( - num_rois, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - rois_.data_ptr(), - channels_out, - output.data_ptr(), - channel_mapping.data_ptr()); - }); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + ps_roi_align_forward_kernel_impl( + num_rois, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_align_forward_kernel_impl( + num_rois, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_align 
only supports float and double types"); + } return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, +Tensor ps_roi_align_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -375,20 +399,19 @@ at::Tensor ps_roi_align_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), + VISION_CHECK(grad.is_cpu(), "grad must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( + channel_mapping.is_cpu(), "channel_mapping must be a CPU tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCPU)); // handle possibly empty gradients if (grad.numel() == 0) { @@ -397,36 +420,52 @@ at::Tensor ps_roi_align_backward_kernel( int channels_out = channels / (pooled_height * pooled_width); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { - ps_roi_align_backward_kernel_impl( - grad.numel(), - grad_.data_ptr(), - channel_mapping.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, 
- sampling_ratio, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); + auto grad_ = torch::stable::contiguous(grad); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + ps_roi_align_backward_kernel_impl( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_align_backward_kernel_impl( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_align backward only supports float and double types"); + } return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("ps_roi_align", TORCH_BOX(&ps_roi_align_forward_kernel)); + m.impl("_ps_roi_align_backward", TORCH_BOX(&ps_roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp index 607cbe4bab6..0062d485979 100644 --- a/torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp @@ -1,11 +1,13 @@ -#include -#include +#include "../../StableABICompat.h" +#include namespace vision { namespace ops { namespace { +using namespace vision::stable; + template inline void add(T* address, const T& val) { *address += val; @@ -43,12 +45,12 @@ void 
ps_roi_pool_forward_kernel_impl( for (int c_out = 0; c_out < channels_out; ++c_out) { for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hstart = static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(std::floor(static_cast(pw) * bin_size_w)); int hend = - static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); int wend = - static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); @@ -111,10 +113,10 @@ void ps_roi_pool_backward_kernel_impl( for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + int hstart = static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(std::floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries hstart = std::min(std::max(hstart + roi_start_h, 0), height); @@ -146,67 +148,87 @@ void ps_roi_pool_backward_kernel_impl( } } -std::tuple ps_roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple ps_roi_pool_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width) { // Check if input tensors are CPU tensors - 
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( + VISION_CHECK(input.is_cpu(), "input must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); int num_rois = rois.size(0); int channels = input.size(1); int height = input.size(2); int width = input.size(3); - TORCH_CHECK( + VISION_CHECK( channels % (pooled_height * pooled_width) == 0, "input channels must be a multiple of pooling height * pooling width"); int channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); + Tensor output = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCPU)); + Tensor channel_mapping = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + kInt, + Device(kCPU)); auto output_size = output.numel(); if (output_size == 0) { return std::make_tuple(output, channel_mapping); } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { - ps_roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - channels_out, - num_rois, - output.data_ptr(), - channel_mapping.data_ptr()); - }); + auto input_ = torch::stable::contiguous(input); + auto rois_ = 
torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + ps_roi_pool_forward_kernel_impl( + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + channels_out, + num_rois, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_pool_forward_kernel_impl( + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + channels_out, + num_rois, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else { + VISION_CHECK(false, "ps_roi_pool only supports float and double types"); + } return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, +Tensor ps_roi_pool_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -215,21 +237,19 @@ at::Tensor ps_roi_pool_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), - "channel_mapping must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); + VISION_CHECK(grad.is_cpu(), "grad must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( + channel_mapping.is_cpu(), "channel_mapping must be a CPU tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have 
the same dtype"); auto num_rois = rois.size(0); - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCPU)); // handle possibly empty gradients if (grad.numel() == 0) { @@ -238,35 +258,50 @@ at::Tensor ps_roi_pool_backward_kernel( int channels_out = channels / (pooled_height * pooled_width); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { - ps_roi_pool_backward_kernel_impl( - grad_.data_ptr(), - channel_mapping.data_ptr(), - num_rois, - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); + auto grad_ = torch::stable::contiguous(grad); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + ps_roi_pool_backward_kernel_impl( + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_pool_backward_kernel_impl( + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_pool backward only supports float and double types"); + } return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, 
CPU, m) { + m.impl("ps_roi_pool", TORCH_BOX(&ps_roi_pool_forward_kernel)); + m.impl("_ps_roi_pool_backward", TORCH_BOX(&ps_roi_pool_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cpu/roi_align_common.h b/torchvision/csrc/ops/cpu/roi_align_common.h index e10c67b5b79..03e5084357f 100644 --- a/torchvision/csrc/ops/cpu/roi_align_common.h +++ b/torchvision/csrc/ops/cpu/roi_align_common.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace vision { namespace ops { diff --git a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp index e0185da45df..6ab577efb52 100644 --- a/torchvision/csrc/ops/cpu/roi_align_kernel.cpp +++ b/torchvision/csrc/ops/cpu/roi_align_kernel.cpp @@ -1,5 +1,5 @@ -#include -#include +#include "../../StableABICompat.h" +#include #include "./roi_align_common.h" @@ -8,6 +8,8 @@ namespace ops { namespace { +using namespace vision::stable; + template void roi_align_forward_kernel_impl( int n_rois, @@ -52,9 +54,9 @@ void roi_align_forward_kernel_impl( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 + : std::ceil(roi_height / pooled_height); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_width / pooled_width); // We do average (integral) pooling inside a bin // When the grid is empty, output zeros. @@ -230,9 +232,9 @@ void roi_align_backward_kernel_impl( // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 + : std::ceil(roi_height / pooled_height); // e.g., = 2 int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + (sampling_ratio > 0) ? 
sampling_ratio : std::ceil(roi_width / pooled_width); // We do average (integral) pooling inside a bin const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 @@ -281,58 +283,76 @@ void roi_align_backward_kernel_impl( } // for } -at::Tensor roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +Tensor roi_align_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, bool aligned) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); + VISION_CHECK(input.is_cpu(), "input must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); + Tensor output = zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCPU)); if (output.numel() == 0) { return output; } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_kernel", [&] { - roi_align_forward_kernel_impl( - num_rois, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - rois_.data_ptr(), - output.data_ptr()); - }); + auto input_ = 
torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + roi_align_forward_kernel_impl( + num_rois, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.const_data_ptr(), + output.mutable_data_ptr()); + } else if (dtype == kDouble) { + roi_align_forward_kernel_impl( + num_rois, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.const_data_ptr(), + output.mutable_data_ptr()); + } else { + VISION_CHECK(false, "roi_align only supports float and double types"); + } return output; } -at::Tensor roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, +Tensor roi_align_backward_kernel( + const Tensor& grad, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -342,16 +362,16 @@ at::Tensor roi_align_backward_kernel( int64_t width, int64_t sampling_ratio, bool aligned) { - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK(grad.is_cpu(), "grad must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCPU)); // handle possibly empty gradients if (grad.numel() == 0) { @@ -364,39 +384,57 @@ at::Tensor roi_align_backward_kernel( int h_stride = grad.stride(2); int 
w_stride = grad.stride(3); - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_kernel", [&] { - roi_align_backward_kernel_impl( - grad.numel(), - grad.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + roi_align_backward_kernel_impl( + grad.numel(), + grad.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + } else if (dtype == kDouble) { + roi_align_backward_kernel_impl( + grad.numel(), + grad.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + } else { + VISION_CHECK( + false, "roi_align backward only supports float and double types"); + } return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("roi_align", TORCH_BOX(&roi_align_forward_kernel)); + m.impl("_roi_align_backward", TORCH_BOX(&roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp b/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp index b099523896a..a28f473c0bb 100644 --- a/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp +++ 
b/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp @@ -1,13 +1,15 @@ #include -#include -#include +#include "../../StableABICompat.h" +#include namespace vision { namespace ops { namespace { +using namespace vision::stable; + template inline void add(T* address, const T& val) { *address += val; @@ -42,10 +44,10 @@ void roi_pool_forward_kernel_impl( for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + int hstart = static_cast(std::floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(std::floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(std::ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(std::ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries hstart = std::min(std::max(hstart + roi_start_h, 0), height); @@ -125,58 +127,76 @@ void roi_pool_backward_kernel_impl( } // num_rois } -std::tuple roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple roi_pool_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); + VISION_CHECK(input.is_cpu(), "input must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); int num_rois = rois.size(0); int channels = input.size(1); 
int height = input.size(2); int width = input.size(3); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros( + Tensor output = zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCPU)); + Tensor argmax = zeros( {num_rois, channels, pooled_height, pooled_width}, - input.options().dtype(at::kInt)); + kInt, + Device(kCPU)); if (output.numel() == 0) { return std::make_tuple(output, argmax); } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_kernel", [&] { - roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - num_rois, - output.data_ptr(), - argmax.data_ptr()); - }); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + roi_pool_forward_kernel_impl( + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + num_rois, + output.mutable_data_ptr(), + argmax.mutable_data_ptr()); + } else if (dtype == kDouble) { + roi_pool_forward_kernel_impl( + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + num_rois, + output.mutable_data_ptr(), + argmax.mutable_data_ptr()); + } else { + VISION_CHECK(false, "roi_pool only supports float and double types"); + } return std::make_tuple(output, argmax); } -at::Tensor roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, +Tensor roi_pool_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& argmax, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -185,21 +205,21 @@ 
at::Tensor roi_pool_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor"); - TORCH_CHECK( + VISION_CHECK(grad.is_cpu(), "grad must be a CPU tensor"); + VISION_CHECK(rois.is_cpu(), "rois must be a CPU tensor"); + VISION_CHECK(argmax.is_cpu(), "argmax must be a CPU tensor"); + VISION_CHECK( rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); auto num_rois = rois.size(0); - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCPU)); // handle possibly empty gradients if (grad.numel() == 0) { @@ -212,37 +232,52 @@ at::Tensor roi_pool_backward_kernel( int h_stride = grad.stride(2); int w_stride = grad.stride(3); - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_pool_backward_kernel", [&] { - roi_pool_backward_kernel_impl( - grad.data_ptr(), - argmax.data_ptr(), - num_rois, - channels, - height, - width, - pooled_height, - pooled_width, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + roi_pool_backward_kernel_impl( + grad.const_data_ptr(), + argmax.const_data_ptr(), + num_rois, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + 
n_stride, + c_stride, + h_stride, + w_stride); + } else if (dtype == kDouble) { + roi_pool_backward_kernel_impl( + grad.const_data_ptr(), + argmax.const_data_ptr(), + num_rois, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + } else { + VISION_CHECK(false, "roi_pool backward only supports float and double types"); + } return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("roi_pool", TORCH_BOX(&roi_pool_forward_kernel)); + m.impl("_roi_pool_backward", TORCH_BOX(&roi_pool_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu index 44ce8db6b8e..78e0c2bf099 100644 --- a/torchvision/csrc/ops/cuda/nms_kernel.cu +++ b/torchvision/csrc/ops/cuda/nms_kernel.cu @@ -1,8 +1,6 @@ -#include -#include -#include -#include -#include +#include "../../StableABICompat.h" +#include +#include #include "cuda_helpers.h" @@ -11,6 +9,8 @@ namespace ops { namespace { +using namespace vision::stable; + int const threadsPerBlock = sizeof(unsigned long long) * 8; template @@ -18,13 +18,52 @@ __device__ inline bool devIoU( T const* const a, T const* const b, const float threshold) { - T left = max(a[0], b[0]), right = min(a[2], b[2]); - T top = max(a[1], b[1]), bottom = min(a[3], b[3]); - T width = max(right - left, (T)0), height = max(bottom - top, (T)0); - using acc_T = at::acc_type; - acc_T interS = (acc_T)width * height; - acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); - acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]); + // Use float for all arithmetic to avoid issues with half operators being disabled + 
float a0 = __half2float(a[0]); + float a1 = __half2float(a[1]); + float a2 = __half2float(a[2]); + float a3 = __half2float(a[3]); + float b0 = __half2float(b[0]); + float b1 = __half2float(b[1]); + float b2 = __half2float(b[2]); + float b3 = __half2float(b[3]); + + float left = max(a0, b0), right = min(a2, b2); + float top = max(a1, b1), bottom = min(a3, b3); + float width = max(right - left, 0.0f), height = max(bottom - top, 0.0f); + float interS = width * height; + float Sa = (a2 - a0) * (a3 - a1); + float Sb = (b2 - b0) * (b3 - b1); + return (interS / (Sa + Sb - interS)) > threshold; +} + +// Specialization for float - just use values directly +template <> +__device__ inline bool devIoU( + float const* const a, + float const* const b, + const float threshold) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left, 0.0f), height = max(bottom - top, 0.0f); + float interS = width * height; + float Sa = (a[2] - a[0]) * (a[3] - a[1]); + float Sb = (b[2] - b[0]) * (b[3] - b[1]); + return (interS / (Sa + Sb - interS)) > threshold; +} + +// Specialization for double +template <> +__device__ inline bool devIoU( + double const* const a, + double const* const b, + const float threshold) { + double left = max(a[0], b[0]), right = min(a[2], b[2]); + double top = max(a[1], b[1]), bottom = min(a[3], b[3]); + double width = max(right - left, 0.0), height = max(bottom - top, 0.0); + double interS = width * height; + double Sa = (a[2] - a[0]) * (a[3] - a[1]); + double Sb = (b[2] - b[0]) * (b[3] - b[1]); return (interS / (Sa + Sb - interS)) > threshold; } @@ -122,25 +161,25 @@ __global__ static void gather_keep_from_mask( } } -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, +Tensor nms_kernel( + const Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); - TORCH_CHECK(scores.is_cuda(), 
"scores must be a CUDA tensor"); + VISION_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); + VISION_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK( dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( + VISION_CHECK( dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK( + VISION_CHECK( scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK( + VISION_CHECK( dets.size(0) == scores.size(0), "boxes and scores should have same number of elements in ", "dimension 0, got ", @@ -148,38 +187,49 @@ at::Tensor nms_kernel( " and ", scores.size(0)) - at::cuda::CUDAGuard device_guard(dets.device()); + DeviceGuard device_guard(dets.get_device_index()); if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); + return empty({0}, kLong, Device(kCUDA, dets.get_device_index())); } - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - auto dets_sorted = dets.index_select(0, order_t).contiguous(); + // Sort scores descending and get indices + auto [sorted_scores, order_t] = sort(scores, /*dim=*/0, /*descending=*/true); + auto dets_sorted = torch::stable::contiguous(index_select(dets, 0, order_t)); int dets_num = dets.size(0); const int col_blocks = ceil_div(dets_num, threadsPerBlock); - at::Tensor mask = - at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + Tensor mask = empty({dets_num * col_blocks}, kLong, Device(kCUDA, dets.get_device_index())); dim3 blocks(col_blocks, col_blocks); dim3 threads(threadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - dets_sorted.scalar_type(), "nms_kernel", [&] { - nms_kernel_impl<<>>( - dets_num, - iou_threshold, - dets_sorted.data_ptr(), - (unsigned long long*)mask.data_ptr()); - }); + // Get CUDA stream + void* stream_ptr; + 
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + dets.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + + auto dtype = dets_sorted.scalar_type(); + if (dtype == kFloat) { + nms_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.const_data_ptr(), + (unsigned long long*)mask.mutable_data_ptr()); + } else if (dtype == kDouble) { + nms_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.const_data_ptr(), + (unsigned long long*)mask.mutable_data_ptr()); + } else { + VISION_CHECK(false, "nms only supports float and double types"); + } - at::Tensor keep = - at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA)); + Tensor keep = zeros({dets_num}, kBool, Device(kCUDA, dets.get_device_index())); // Unwrap the mask to fill keep with proper values // Keeping the unwrap on device instead of applying iterative for loops on cpu @@ -191,18 +241,18 @@ at::Tensor nms_kernel( min(col_blocks, threadsPerBlock), col_blocks * sizeof(unsigned long long), stream>>>( - keep.data_ptr(), - (unsigned long long*)mask.data_ptr(), + keep.mutable_data_ptr(), + (unsigned long long*)mask.const_data_ptr(), dets_num); - AT_CUDA_CHECK(cudaGetLastError()); - return order_t.masked_select(keep); + STD_CUDA_KERNEL_LAUNCH_CHECK(); + return masked_select(order_t, keep); } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("nms", TORCH_BOX(&nms_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu index 105c6a14256..ac00f406adb 100644 --- a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu +++ b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu @@ -1,8 +1,6 @@ -#include -#include -#include -#include -#include +#include "../../StableABICompat.h" +#include +#include #include "cuda_helpers.h" @@ -11,6 
+9,8 @@ namespace ops { namespace { +using namespace vision::stable; + template __device__ T bilinear_interpolate( const T* input, @@ -212,8 +212,7 @@ __global__ void ps_roi_align_backward_kernel_impl( int sampling_ratio, int channels_out, T* grad_input, - const T* rois, - const int memory_span) { + const T* rois) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, *, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -286,112 +285,117 @@ __global__ void ps_roi_align_backward_kernel_impl( T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - at::native::fastAtomicAdd( - grad_input, - offset + y_low * width + x_low, - memory_span, - static_cast(g1), - true); - at::native::fastAtomicAdd( - grad_input, - offset + y_low * width + x_high, - memory_span, - static_cast(g2), - true); - at::native::fastAtomicAdd( - grad_input, - offset + y_high * width + x_low, - memory_span, - static_cast(g3), - true); - at::native::fastAtomicAdd( - grad_input, - offset + y_high * width + x_high, - memory_span, - static_cast(g4), - true); + atomicAdd(grad_input + offset + y_low * width + x_low, g1); + atomicAdd(grad_input + offset + y_low * width + x_high, g2); + atomicAdd(grad_input + offset + y_high * width + x_low, g3); + atomicAdd(grad_input + offset + y_high * width + x_high, g4); } // if } // ix } // iy } } -std::tuple ps_roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple ps_roi_align_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio) { // Check if input tensors are CUDA tensors - TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK( rois.size(1) == 5, 
"Tensor rois should have shape as Tensor[K, 5]"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - at::cuda::CUDAGuard device_guard(input.device()); + DeviceGuard device_guard(input.get_device_index()); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - TORCH_CHECK( + VISION_CHECK( channels % (pooled_height * pooled_width) == 0, "input channels must be a multiple of pooling height * pooling width"); int channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); + Tensor output = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCUDA, input.get_device_index())); + Tensor channel_mapping = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + kInt, + Device(kCUDA, input.get_device_index())); auto output_size = output.numel(); if (output_size == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, channel_mapping); } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + input.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096))); dim3 block(512); - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_kernel", [&] { 
- ps_roi_align_forward_kernel_impl<<>>( - output_size, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - rois_.data_ptr(), - channels_out, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - AT_CUDA_CHECK(cudaGetLastError()); - cudaDeviceSynchronize(); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + ps_roi_align_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_align_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_align only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, +Tensor ps_roi_align_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -401,24 +405,27 @@ at::Tensor ps_roi_align_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CUDA tensors - TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA 
tensor"); + VISION_CHECK( channel_mapping.is_cuda(), "channel_mapping must be a CUDA tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; + DeviceGuard device_guard(grad.get_device_index()); - at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - at::cuda::CUDAGuard device_guard(grad.device()); - - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + auto num_rois = rois.size(0); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCUDA, grad.get_device_index())); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + grad.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(grad.numel()), static_cast(512)), @@ -427,46 +434,60 @@ at::Tensor ps_roi_align_backward_kernel( // handle possibly empty gradients if (grad.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } int channels_out = channels / (pooled_height * pooled_width); - at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { - ps_roi_align_backward_kernel_impl<<>>( - grad.numel(), - grad_.data_ptr(), - channel_mapping.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr(), - grad_input.numel()); - }); - 
AT_CUDA_CHECK(cudaGetLastError()); + auto grad_ = torch::stable::contiguous(grad); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + ps_roi_align_backward_kernel_impl<<>>( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_align_backward_kernel_impl<<>>( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_align backward only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("ps_roi_align", TORCH_BOX(&ps_roi_align_forward_kernel)); + m.impl("_ps_roi_align_backward", TORCH_BOX(&ps_roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu index 2c90690f4a5..ba2d5599e1c 100644 --- a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu +++ b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu @@ -1,8 +1,5 @@ -#include -#include -#include -#include -#include +#include "../../StableABICompat.h" +#include #include "cuda_helpers.h" @@ -11,6 +8,8 @@ namespace ops { namespace { +using namespace vision::stable; + template __global__ void 
ps_roi_pool_forward_kernel_impl( int nthreads, @@ -91,8 +90,7 @@ __global__ void ps_roi_pool_backward_kernel_impl( int pooled_width, int channels_out, T* grad_input, - const T* rois, - const int memory_span) { + const T* rois) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, *, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -132,86 +130,110 @@ __global__ void ps_roi_pool_backward_kernel_impl( for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int grad_input_index = h * width + w; - at::native::fastAtomicAdd( - grad_input, offset + grad_input_index, memory_span, diff_val, true); + atomicAdd(grad_input + offset + grad_input_index, diff_val); } } } } -std::tuple ps_roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple ps_roi_pool_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width) { // Check if input tensors are CUDA tensors - TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK( rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - at::cuda::CUDAGuard device_guard(input.device()); + DeviceGuard device_guard(input.get_device_index()); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - TORCH_CHECK( + VISION_CHECK( channels % (pooled_height * pooled_width) == 0, "input channels must be 
a multiple of pooling height * pooling width"); int channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); + Tensor output = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + input.scalar_type(), + Device(kCUDA, input.get_device_index())); + Tensor channel_mapping = zeros( + {num_rois, channels_out, pooled_height, pooled_width}, + kInt, + Device(kCUDA, input.get_device_index())); auto output_size = output.numel(); if (output_size == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, channel_mapping); } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + input.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096))); dim3 block(512); - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { - ps_roi_pool_forward_kernel_impl<<>>( - output_size, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - channels_out, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - AT_CUDA_CHECK(cudaGetLastError()); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + ps_roi_pool_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + 
channel_mapping.mutable_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_pool_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + channels_out, + output.mutable_data_ptr(), + channel_mapping.mutable_data_ptr()); + } else { + VISION_CHECK( + false, "ps_roi_pool only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, +Tensor ps_roi_pool_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -220,25 +242,27 @@ at::Tensor ps_roi_pool_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CUDA tensors - TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK( channel_mapping.is_cuda(), "channel_mapping must be a CUDA tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - at::cuda::CUDAGuard device_guard(grad.device()); + DeviceGuard device_guard(grad.get_device_index()); auto num_rois = rois.size(0); - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( + {batch_size, channels, height, 
width}, + grad.scalar_type(), + Device(kCUDA, grad.get_device_index())); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + grad.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(grad.numel()), static_cast(512)), @@ -247,46 +271,61 @@ at::Tensor ps_roi_pool_backward_kernel( // handle possibly empty gradients if (grad.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } int channels_out = channels / (pooled_height * pooled_width); - at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { - ps_roi_pool_backward_kernel_impl<<>>( - grad.numel(), - grad_.data_ptr(), - channel_mapping.data_ptr(), - num_rois, - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr(), - grad_input.numel()); - }); - AT_CUDA_CHECK(cudaGetLastError()); + auto grad_ = torch::stable::contiguous(grad); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == kFloat) { + ps_roi_pool_backward_kernel_impl<<>>( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr()); + } else if (dtype == kDouble) { + ps_roi_pool_backward_kernel_impl<<>>( + grad.numel(), + grad_.const_data_ptr(), + channel_mapping.const_data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.mutable_data_ptr(), + 
rois_.const_data_ptr()); + } else { + VISION_CHECK( + false, + "ps_roi_pool backward only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("ps_roi_pool", TORCH_BOX(&ps_roi_pool_forward_kernel)); + m.impl("_ps_roi_pool_backward", TORCH_BOX(&ps_roi_pool_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/roi_align_kernel.cu b/torchvision/csrc/ops/cuda/roi_align_kernel.cu index 26c53448663..9a5ed82832d 100644 --- a/torchvision/csrc/ops/cuda/roi_align_kernel.cu +++ b/torchvision/csrc/ops/cuda/roi_align_kernel.cu @@ -1,8 +1,5 @@ -#include -#include -#include -#include -#include +#include "../../StableABICompat.h" +#include #include "cuda_helpers.h" @@ -11,6 +8,8 @@ namespace ops { namespace { +using namespace vision::stable; + template __device__ T bilinear_interpolate( const T* input, @@ -218,8 +217,7 @@ __global__ void roi_align_backward_kernel_impl( int n_stride, int c_stride, int h_stride, - int w_stride, - const int memory_span) { + int w_stride) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -301,66 +299,58 @@ __global__ void roi_align_backward_kernel_impl( T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - at::native::fastAtomicAdd( - grad_input, - input_offset + y_low * width + x_low, - memory_span, - static_cast(g1), - true); - at::native::fastAtomicAdd( - grad_input, - input_offset + y_low * width + x_high, - memory_span, - static_cast(g2), - true); - at::native::fastAtomicAdd( - grad_input, - input_offset + y_high * width + 
x_low, - memory_span, - static_cast(g3), - true); - at::native::fastAtomicAdd( - grad_input, - input_offset + y_high * width + x_high, - memory_span, - static_cast(g4), - true); + atomicAdd( + grad_input + input_offset + y_low * width + x_low, + static_cast(g1)); + atomicAdd( + grad_input + input_offset + y_low * width + x_high, + static_cast(g2)); + atomicAdd( + grad_input + input_offset + y_high * width + x_low, + static_cast(g3)); + atomicAdd( + grad_input + input_offset + y_high * width + x_high, + static_cast(g4)); } // if } // ix } // iy } // CUDA_1D_KERNEL_LOOP } -at::Tensor roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +Tensor roi_align_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t sampling_ratio, bool aligned) { - TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + VISION_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - at::cuda::CUDAGuard device_guard(input.device()); + DeviceGuard device_guard(input.get_device_index()); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); + Tensor output = zeros( + {num_rois, channels, pooled_height, pooled_width}, + 
input.scalar_type(), + Device(kCUDA, input.get_device_index())); auto output_size = num_rois * pooled_height * pooled_width * channels; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + input.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(output_size), static_cast(512)), @@ -368,34 +358,54 @@ at::Tensor roi_align_forward_kernel( dim3 block(512); if (output.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return output; } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_kernel", [&] { - roi_align_forward_kernel_impl<<>>( - output_size, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - rois_.data_ptr(), - output.data_ptr()); - }); - AT_CUDA_CHECK(cudaGetLastError()); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + roi_align_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.const_data_ptr(), + output.mutable_data_ptr()); + } else if (dtype == kDouble) { + roi_align_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.const_data_ptr(), + output.mutable_data_ptr()); + } else { + VISION_CHECK( + false, "roi_align only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return output; } -at::Tensor roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& 
rois, +Tensor roi_align_backward_kernel( + const Tensor& grad, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -405,21 +415,24 @@ at::Tensor roi_align_backward_kernel( int64_t width, int64_t sampling_ratio, bool aligned) { - TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + VISION_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); - at::checkAllSameType(c, {grad_t, rois_t}); + DeviceGuard device_guard(grad.get_device_index()); - at::cuda::CUDAGuard device_guard(grad.device()); + Tensor grad_input = zeros( + {batch_size, channels, height, width}, + grad.scalar_type(), + Device(kCUDA, grad.get_device_index())); - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + grad.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(grad.numel()), static_cast(512)), @@ -428,7 +441,7 @@ at::Tensor roi_align_backward_kernel( // handle possibly empty gradients if (grad.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } @@ -437,43 +450,60 @@ at::Tensor roi_align_backward_kernel( int h_stride = grad.stride(2); int w_stride = grad.stride(3); - at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = grad.scalar_type(); + if (dtype == 
kFloat) { + roi_align_backward_kernel_impl<<>>( + grad.numel(), + grad.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + } else if (dtype == kDouble) { + roi_align_backward_kernel_impl<<>>( + grad.numel(), + grad.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.mutable_data_ptr(), + rois_.const_data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + } else { + VISION_CHECK( + false, + "roi_align backward only supports float and double types"); + } - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_kernel", [&] { - roi_align_backward_kernel_impl<<>>( - grad.numel(), - grad.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride, - grad_input.numel()); - }); - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("roi_align", TORCH_BOX(&roi_align_forward_kernel)); + m.impl("_roi_align_backward", TORCH_BOX(&roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu index 3a9374bb438..cf33ca67442 100644 --- a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu @@ -1,9 +1,7 @@ 
-#include -#include -#include #include -#include -#include + +#include "../../StableABICompat.h" +#include #include "cuda_helpers.h" @@ -12,6 +10,8 @@ namespace ops { namespace { +using namespace vision::stable; + template __global__ void roi_pool_forward_kernel_impl( int nthreads, @@ -94,8 +94,7 @@ __global__ void roi_pool_backward_kernel_impl( int n_stride, int c_stride, int h_stride, - int w_stride, - const int memory_span) { + int w_stride) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -113,49 +112,51 @@ __global__ void roi_pool_backward_kernel_impl( const int offset = (roi_batch_ind * channels + c) * height * width; if (argmax != -1) { - at::native::fastAtomicAdd( - grad_input, - offset + argmax, - memory_span, + atomicAdd( + grad_input + offset + argmax, static_cast( - grad_output[output_offset + ph * h_stride + pw * w_stride]), - true); + grad_output[output_offset + ph * h_stride + pw * w_stride])); } } } -std::tuple roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, +std::tuple roi_pool_forward_kernel( + const Tensor& input, + const Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width) { - TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK( + VISION_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK( rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + VISION_CHECK( + input.scalar_type() == rois.scalar_type(), + "input and rois must have the same dtype"); - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - at::cuda::CUDAGuard device_guard(input.device()); + DeviceGuard 
device_guard(input.get_device_index()); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros( + Tensor output = zeros( {num_rois, channels, pooled_height, pooled_width}, - input.options().dtype(at::kInt)); + input.scalar_type(), + Device(kCUDA, input.get_device_index())); + Tensor argmax = zeros( + {num_rois, channels, pooled_height, pooled_width}, + kInt, + Device(kCUDA, input.get_device_index())); auto output_size = num_rois * pooled_height * pooled_width * channels; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Get CUDA stream + void* stream_ptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream( + input.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); dim3 grid(std::min( ceil_div(static_cast(output_size), static_cast(512)), @@ -163,34 +164,52 @@ std::tuple roi_pool_forward_kernel( dim3 block(512); if (output.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, argmax); } - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_kernel", [&] { - roi_pool_forward_kernel_impl<<>>( - output_size, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - output.data_ptr(), - argmax.data_ptr()); - }); - AT_CUDA_CHECK(cudaGetLastError()); + auto input_ = torch::stable::contiguous(input); + auto rois_ = torch::stable::contiguous(rois); + + auto dtype = input.scalar_type(); + if (dtype == kFloat) { + roi_pool_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + 
output.mutable_data_ptr(), + argmax.mutable_data_ptr()); + } else if (dtype == kDouble) { + roi_pool_forward_kernel_impl<<>>( + output_size, + input_.const_data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.const_data_ptr(), + output.mutable_data_ptr(), + argmax.mutable_data_ptr()); + } else { + VISION_CHECK(false, "roi_pool only supports float and double types"); + } + + STD_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(output, argmax); } -at::Tensor roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, +Tensor roi_pool_backward_kernel( + const Tensor& grad, + const Tensor& rois, + const Tensor& argmax, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -199,25 +218,29 @@ at::Tensor roi_pool_backward_kernel( int64_t height, int64_t width) { // Check if input tensors are CUDA tensors - TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); - TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); - TORCH_CHECK(argmax.is_cuda(), "argmax must be a CUDA tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - argmax_t{argmax, "argmax", 3}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); - at::checkAllSameType(c, {grad_t, rois_t}); + VISION_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); + VISION_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); + VISION_CHECK(argmax.is_cuda(), "argmax must be a CUDA tensor"); + VISION_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + VISION_CHECK( + grad.scalar_type() == rois.scalar_type(), + "grad and rois must have the same dtype"); - at::cuda::CUDAGuard device_guard(grad.device()); + DeviceGuard device_guard(grad.get_device_index()); auto num_rois = rois.size(0); - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); + Tensor grad_input = zeros( 
+      {batch_size, channels, height, width},
+      grad.scalar_type(),
+      Device(kCUDA, grad.get_device_index()));
 
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // Get CUDA stream
+  void* stream_ptr;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(
+      grad.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
 
   dim3 grid(std::min(
       ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
@@ -226,7 +249,7 @@ at::Tensor roi_pool_backward_kernel(
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
+    STD_CUDA_KERNEL_LAUNCH_CHECK();
     return grad_input;
   }
 
@@ -235,43 +258,60 @@ at::Tensor roi_pool_backward_kernel(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
-  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
-
-  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "roi_pool_backward_kernel", [&] {
-        roi_pool_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
-            grad.numel(),
-            grad.data_ptr<scalar_t>(),
-            argmax_.data_ptr<int>(),
-            num_rois,
-            spatial_scale,
-            channels,
-            height,
-            width,
-            pooled_height,
-            pooled_width,
-            grad_input.data_ptr<scalar_t>(),
-            rois_.data_ptr<scalar_t>(),
-            n_stride,
-            c_stride,
-            h_stride,
-            w_stride,
-            grad_input.numel());
-      });
-  AT_CUDA_CHECK(cudaGetLastError());
+  auto argmax_ = torch::stable::contiguous(argmax);
+  auto rois_ = torch::stable::contiguous(rois);
+
+  auto dtype = grad.scalar_type();
+  if (dtype == kFloat) {
+    roi_pool_backward_kernel_impl<float><<<grid, block, 0, stream>>>(
+        grad.numel(),
+        grad.const_data_ptr<float>(),
+        argmax_.const_data_ptr<int>(),
+        num_rois,
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        grad_input.mutable_data_ptr<float>(),
+        rois_.const_data_ptr<float>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
+  } else if (dtype == kDouble) {
+    roi_pool_backward_kernel_impl<double><<<grid, block, 0, stream>>>(
+        grad.numel(),
+        grad.const_data_ptr<double>(),
+        argmax_.const_data_ptr<int>(),
+        num_rois,
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        grad_input.mutable_data_ptr<double>(),
+        rois_.const_data_ptr<double>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
+  } else {
+    VISION_CHECK(
+        false, "roi_pool backward only supports float and double types");
+  }
+
+  STD_CUDA_KERNEL_LAUNCH_CHECK();
 
   return grad_input;
 }
 
 } // namespace
 
-TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
-      TORCH_FN(roi_pool_forward_kernel));
-  m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
-      TORCH_FN(roi_pool_backward_kernel));
+STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
+  m.impl("roi_pool", TORCH_BOX(&roi_pool_forward_kernel));
+  m.impl("_roi_pool_backward", TORCH_BOX(&roi_pool_backward_kernel));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp
index 5ecf8812f1b..0b691c7f578 100644
--- a/torchvision/csrc/ops/nms.cpp
+++ b/torchvision/csrc/ops/nms.cpp
@@ -1,28 +1,17 @@
 #include "nms.h"
 
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <torch/library.h>
-#include <torch/types.h>
+#include <torch/csrc/stable/library.h>
 
 namespace vision {
 namespace ops {
 
-at::Tensor nms(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold) {
-  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms");
-  static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchvision::nms", "")
-                       .typed<decltype(nms)>();
-  return op.call(dets, scores, iou_threshold);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchvision, m) {
-  m.set_python_module("torchvision._meta_registrations");
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
-}
+// Note: With stable ABI, ops are called directly via torch.ops.torchvision.nms
+// The dispatcher wrapper functions are no longer needed.
} // namespace ops } // namespace vision + +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); +} diff --git a/torchvision/csrc/ops/nms.h b/torchvision/csrc/ops/nms.h index 8c75a242bff..4edcdabb3c3 100644 --- a/torchvision/csrc/ops/nms.h +++ b/torchvision/csrc/ops/nms.h @@ -1,15 +1,13 @@ #pragma once -#include +#include "../StableABICompat.h" #include "../macros.h" namespace vision { namespace ops { -VISION_API at::Tensor nms( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); +// Note: With stable ABI, nms is called directly via torch.ops.torchvision.nms +// This header is kept for backwards compatibility but the C++ API is deprecated. } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ops/ps_roi_align.cpp b/torchvision/csrc/ops/ps_roi_align.cpp index de458c0d62d..2243241463c 100644 --- a/torchvision/csrc/ops/ps_roi_align.cpp +++ b/torchvision/csrc/ops/ps_roi_align.cpp @@ -1,112 +1,19 @@ #include "ps_roi_align.h" -#include -#include -#include +#include namespace vision { namespace ops { -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - 
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// The dispatcher wrapper functions 
are no longer needed. } // namespace ops } // namespace vision + +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"); +} diff --git a/torchvision/csrc/ops/ps_roi_align.h b/torchvision/csrc/ops/ps_roi_align.h index 75650586bc6..b049ce9d6fa 100644 --- a/torchvision/csrc/ops/ps_roi_align.h +++ b/torchvision/csrc/ops/ps_roi_align.h @@ -1,56 +1,13 @@ #pragma once -#include +#include "../StableABICompat.h" #include "../macros.h" namespace vision { namespace ops { -VISION_API std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio); - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail +// Note: With stable ABI, ops are called directly 
via torch.ops.torchvision.* +// This header is kept for backwards compatibility but the C++ API is deprecated. } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ops/ps_roi_pool.cpp b/torchvision/csrc/ops/ps_roi_pool.cpp index 92469d5e380..3f7b607e3f1 100644 --- a/torchvision/csrc/ops/ps_roi_pool.cpp +++ b/torchvision/csrc/ops/ps_roi_pool.cpp @@ -1,104 +1,22 @@ #include "ps_roi_pool.h" -#include -#include -#include +#include namespace vision { namespace ops { -std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const 
at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// The dispatcher wrapper functions are no longer needed. 
} // namespace ops } // namespace vision + +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, " + "SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"); + m.def( + "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, " + "float spatial_scale, SymInt pooled_height, SymInt pooled_width, " + "SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"); +} diff --git a/torchvision/csrc/ops/ps_roi_pool.h b/torchvision/csrc/ops/ps_roi_pool.h index 4a3cc54e0e5..b049ce9d6fa 100644 --- a/torchvision/csrc/ops/ps_roi_pool.h +++ b/torchvision/csrc/ops/ps_roi_pool.h @@ -1,52 +1,13 @@ #pragma once -#include +#include "../StableABICompat.h" #include "../macros.h" namespace vision { namespace ops { -VISION_API std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// This header is kept for backwards compatibility but the C++ API is deprecated. 
} // namespace ops } // namespace vision diff --git a/torchvision/csrc/ops/roi_align.cpp b/torchvision/csrc/ops/roi_align.cpp index aa6dccb44f2..51d6421d9cc 100644 --- a/torchvision/csrc/ops/roi_align.cpp +++ b/torchvision/csrc/ops/roi_align.cpp @@ -1,132 +1,24 @@ #include "roi_align.h" -#include -#include -#include +#include namespace vision { namespace ops { -at::Tensor roi_align( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. - int64_t pooled_height, // The height of the pooled feature map. - int64_t pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -at::Tensor roi_align_symint( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. - c10::SymInt pooled_height, // The height of the pooled feature map. - c10::SymInt pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. 
-{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, 
bool aligned) -> Tensor")); -} +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// The dispatcher wrapper functions are no longer needed. } // namespace ops } // namespace vision + +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "roi_align(Tensor input, Tensor rois, float spatial_scale, " + "SymInt pooled_height, SymInt pooled_width, int sampling_ratio, " + "bool aligned) -> Tensor"); + m.def( + "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, " + "SymInt pooled_height, SymInt pooled_width, SymInt batch_size, " + "SymInt channels, SymInt height, SymInt width, int sampling_ratio, " + "bool aligned) -> Tensor"); +} diff --git a/torchvision/csrc/ops/roi_align.h b/torchvision/csrc/ops/roi_align.h index 072d6d4231c..b049ce9d6fa 100644 --- a/torchvision/csrc/ops/roi_align.h +++ b/torchvision/csrc/ops/roi_align.h @@ -1,58 +1,13 @@ #pragma once -#include +#include "../StableABICompat.h" #include "../macros.h" namespace vision { namespace ops { -VISION_API at::Tensor roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned); - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt 
width, - int64_t sampling_ratio, - bool aligned); - -} // namespace detail +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// This header is kept for backwards compatibility but the C++ API is deprecated. } // namespace ops } // namespace vision diff --git a/torchvision/csrc/ops/roi_pool.cpp b/torchvision/csrc/ops/roi_pool.cpp index 20ca3ca91e7..2b18f5c7595 100644 --- a/torchvision/csrc/ops/roi_pool.cpp +++ b/torchvision/csrc/ops/roi_pool.cpp @@ -1,102 +1,22 @@ #include "roi_pool.h" -#include -#include -#include +#include namespace vision { namespace ops { -std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor 
_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// The dispatcher wrapper functions are no longer needed. 
} // namespace ops } // namespace vision + +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def( + "roi_pool(Tensor input, Tensor rois, float spatial_scale, " + "SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"); + m.def( + "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, " + "float spatial_scale, SymInt pooled_height, SymInt pooled_width, " + "SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"); +} diff --git a/torchvision/csrc/ops/roi_pool.h b/torchvision/csrc/ops/roi_pool.h index e2133240f4f..b049ce9d6fa 100644 --- a/torchvision/csrc/ops/roi_pool.h +++ b/torchvision/csrc/ops/roi_pool.h @@ -1,52 +1,13 @@ #pragma once -#include +#include "../StableABICompat.h" #include "../macros.h" namespace vision { namespace ops { -VISION_API std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail +// Note: With stable ABI, ops are called directly via torch.ops.torchvision.* +// This header is kept for backwards compatibility but the C++ API is deprecated. 
} // namespace ops } // namespace vision diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index 82809b8f888..8624d96f0c2 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -8,6 +8,61 @@ from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format +class _PSRoIAlignFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, + sampling_ratio: int, + ) -> Tensor: + ctx.spatial_scale = spatial_scale + ctx.pooled_height = pooled_height + ctx.pooled_width = pooled_width + ctx.sampling_ratio = sampling_ratio + ctx.input_shape = input.shape + output, channel_mapping = torch.ops.torchvision.ps_roi_align( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio + ) + ctx.save_for_backward(rois, channel_mapping) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + rois, channel_mapping = ctx.saved_tensors + batch_size, channels, height, width = ctx.input_shape + grad_input = torch.ops.torchvision._ps_roi_align_backward( + grad_output, + rois, + channel_mapping, + ctx.spatial_scale, + ctx.pooled_height, + ctx.pooled_width, + ctx.sampling_ratio, + batch_size, + channels, + height, + width, + ) + return grad_input, None, None, None, None, None + + +def _ps_roi_align_autograd( + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, + sampling_ratio: int, +) -> Tensor: + return _PSRoIAlignFunction.apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio + ) + + @torch.fx.wrap def ps_roi_align( input: Tensor, @@ -53,10 +108,9 @@ def ps_roi_align( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_align( + return _ps_roi_align_autograd( input, rois, spatial_scale, output_size[0], 
output_size[1], sampling_ratio ) - return output class PSRoIAlign(nn.Module): diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index 15292dcad97..5d5fab5fa55 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -8,6 +8,55 @@ from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format +class _PSRoIPoolFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, + ) -> Tensor: + ctx.spatial_scale = spatial_scale + ctx.pooled_height = pooled_height + ctx.pooled_width = pooled_width + ctx.input_shape = input.shape + output, channel_mapping = torch.ops.torchvision.ps_roi_pool( + input, rois, spatial_scale, pooled_height, pooled_width + ) + ctx.save_for_backward(rois, channel_mapping) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + rois, channel_mapping = ctx.saved_tensors + batch_size, channels, height, width = ctx.input_shape + grad_input = torch.ops.torchvision._ps_roi_pool_backward( + grad_output, + rois, + channel_mapping, + ctx.spatial_scale, + ctx.pooled_height, + ctx.pooled_width, + batch_size, + channels, + height, + width, + ) + return grad_input, None, None, None, None + + +def _ps_roi_pool_autograd( + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, +) -> Tensor: + return _PSRoIPoolFunction.apply(input, rois, spatial_scale, pooled_height, pooled_width) + + @torch.fx.wrap def ps_roi_pool( input: Tensor, @@ -47,8 +96,7 @@ def ps_roi_pool( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]) - return output + return _ps_roi_pool_autograd(input, rois, spatial_scale, output_size[0], output_size[1]) class PSRoIPool(nn.Module): diff --git 
a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 25214d6b130..fef1d269e2c 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -200,6 +200,64 @@ def from_K(t): return output +class _RoIAlignFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, + sampling_ratio: int, + aligned: bool, + ) -> Tensor: + ctx.save_for_backward(rois) + ctx.spatial_scale = spatial_scale + ctx.pooled_height = pooled_height + ctx.pooled_width = pooled_width + ctx.sampling_ratio = sampling_ratio + ctx.aligned = aligned + ctx.input_shape = input.shape + output = torch.ops.torchvision.roi_align( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned + ) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + (rois,) = ctx.saved_tensors + batch_size, channels, height, width = ctx.input_shape + grad_input = torch.ops.torchvision._roi_align_backward( + grad_output, + rois, + ctx.spatial_scale, + ctx.pooled_height, + ctx.pooled_width, + batch_size, + channels, + height, + width, + ctx.sampling_ratio, + ctx.aligned, + ) + return grad_input, None, None, None, None, None, None + + +def _roi_align_autograd( + input: Tensor, + rois: Tensor, + spatial_scale: float, + pooled_height: int, + pooled_width: int, + sampling_ratio: int, + aligned: bool, +) -> Tensor: + return _RoIAlignFunction.apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned + ) + + @torch.fx.wrap def roi_align( input: Tensor, @@ -255,7 +313,7 @@ def roi_align( ) and is_compile_supported(input.device.type): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) _assert_has_ops() - return torch.ops.torchvision.roi_align( + return _roi_align_autograd( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned ) diff --git 
a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py
index 5f4bb95c0f3..c0249921a3f 100644
--- a/torchvision/ops/roi_pool.py
+++ b/torchvision/ops/roi_pool.py
@@ -11,6 +11,55 @@ from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
 
 
+class _RoIPoolFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input: Tensor,
+        rois: Tensor,
+        spatial_scale: float,
+        pooled_height: int,
+        pooled_width: int,
+    ) -> Tensor:
+        ctx.spatial_scale = spatial_scale
+        ctx.pooled_height = pooled_height
+        ctx.pooled_width = pooled_width
+        ctx.input_shape = input.shape
+        output, argmax = torch.ops.torchvision.roi_pool(
+            input, rois, spatial_scale, pooled_height, pooled_width
+        )
+        ctx.save_for_backward(rois, argmax)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        rois, argmax = ctx.saved_tensors
+        batch_size, channels, height, width = ctx.input_shape
+        grad_input = torch.ops.torchvision._roi_pool_backward(
+            grad_output,
+            rois,
+            argmax,
+            ctx.spatial_scale,
+            ctx.pooled_height,
+            ctx.pooled_width,
+            batch_size,
+            channels,
+            height,
+            width,
+        )
+        return grad_input, None, None, None, None
+
+
+def _roi_pool_autograd(
+    input: Tensor,
+    rois: Tensor,
+    spatial_scale: float,
+    pooled_height: int,
+    pooled_width: int,
+) -> Tensor:
+    return _RoIPoolFunction.apply(input, rois, spatial_scale, pooled_height, pooled_width)
+
+
 @torch.fx.wrap
 def roi_pool(
     input: Tensor,
@@ -49,8 +98,7 @@ def roi_pool(
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
-    return output
+    return _roi_pool_autograd(input, rois, spatial_scale, output_size[0], output_size[1])
 
 
 class RoIPool(nn.Module):