From e8a1c9d51f957be84816c14a3ec0985d11b09176 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 23 Jan 2026 17:05:53 +0000 Subject: [PATCH 1/9] Port io/image to stable ABI --- setup.py | 43 +- torchvision/csrc/StableABICompat.h | 558 ++++++++++++++++++ torchvision/csrc/io/image/common.cpp | 15 +- torchvision/csrc/io/image/common.h | 4 +- torchvision/csrc/io/image/cpu/decode_gif.cpp | 58 +- torchvision/csrc/io/image/cpu/decode_gif.h | 4 +- .../csrc/io/image/cpu/decode_image.cpp | 25 +- torchvision/csrc/io/image/cpu/decode_image.h | 6 +- torchvision/csrc/io/image/cpu/decode_jpeg.cpp | 35 +- torchvision/csrc/io/image/cpu/decode_jpeg.h | 10 +- torchvision/csrc/io/image/cpu/decode_png.cpp | 46 +- torchvision/csrc/io/image/cpu/decode_png.h | 6 +- torchvision/csrc/io/image/cpu/decode_webp.cpp | 40 +- torchvision/csrc/io/image/cpu/decode_webp.h | 6 +- torchvision/csrc/io/image/cpu/encode_jpeg.cpp | 42 +- torchvision/csrc/io/image/cpu/encode_jpeg.h | 6 +- torchvision/csrc/io/image/cpu/encode_png.cpp | 43 +- torchvision/csrc/io/image/cpu/encode_png.h | 6 +- torchvision/csrc/io/image/cpu/exif.h | 24 +- .../csrc/io/image/cpu/read_write_file.cpp | 47 +- .../csrc/io/image/cpu/read_write_file.h | 6 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 168 +++--- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 16 +- .../io/image/cuda/encode_decode_jpegs_cuda.h | 34 +- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 95 +-- .../csrc/io/image/cuda/encode_jpegs_cuda.h | 10 +- torchvision/csrc/io/image/image.cpp | 68 ++- torchvision/csrc/io/image/image.h | 4 +- 28 files changed, 1019 insertions(+), 406 deletions(-) create mode 100644 torchvision/csrc/StableABICompat.h diff --git a/setup.py b/setup.py index 6181007924e..223d2419b7a 100644 --- a/setup.py +++ b/setup.py @@ -293,16 +293,21 @@ def make_image_extension(): libraries = [] define_macros, extra_compile_args = get_macros_and_flags() + # PyTorch Stable ABI target version (2.11) - required for string handling in TORCH_BOX + 
define_macros += [("TORCH_TARGET_VERSION", "0x020b000000000000")] image_dir = CSRS_DIR / "io/image" sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c")) - if IS_ROCM: - sources += list(image_dir.glob("hip/*.cpp")) - # we need to exclude this in favor of the hipified source - sources.remove(image_dir / "image.cpp") - else: - sources += list(image_dir.glob("cuda/*.cpp")) + # Note: CUDA sources are excluded when building with stable ABI (TORCH_TARGET_VERSION) + # because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. + # When stable ABI CUDA support is added to PyTorch, this can be re-enabled. + # if IS_ROCM: + # sources += list(image_dir.glob("hip/*.cpp")) + # # we need to exclude this in favor of the hipified source + # sources.remove(image_dir / "image.cpp") + # else: + # sources += list(image_dir.glob("cuda/*.cpp")) Extension = CppExtension @@ -350,18 +355,20 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): - nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - - if nvjpeg_found: - print("Building torchvision with NVJPEG image support") - libraries.append("nvjpeg") - define_macros += [("NVJPEG_FOUND", 1)] - Extension = CUDAExtension - else: - warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") + # NVJPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) + # because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. 
+ # if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + # nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + # + # if nvjpeg_found: + # print("Building torchvision with NVJPEG image support") + # libraries.append("nvjpeg") + # define_macros += [("NVJPEG_FOUND", 1)] + # Extension = CUDAExtension + # else: + # warnings.warn("Building torchvision without NVJPEG support") + # elif USE_NVJPEG: + # warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", diff --git a/torchvision/csrc/StableABICompat.h b/torchvision/csrc/StableABICompat.h new file mode 100644 index 00000000000..31eb5b1cd74 --- /dev/null +++ b/torchvision/csrc/StableABICompat.h @@ -0,0 +1,558 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +// =========================================================================== +// PyTorch Stable ABI Compatibility Header for TorchVision +// =========================================================================== +// +// This header provides compatibility types and macros for using PyTorch's +// stable ABI API. It replaces the standard PyTorch C++ APIs (torch::, at::, +// c10::) with their stable ABI equivalents. +// +// Target PyTorch version: 2.11+ +// +// Note: TORCH_TARGET_VERSION is set to 0x020b000000000000 (PyTorch 2.11) in +// CMakeLists.txt. This ensures we only use stable ABI features available in +// PyTorch 2.11+, providing forward compatibility when building against newer +// PyTorch versions. 
+ +// Include stable ABI headers +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// =========================================================================== +// Error Handling Macro +// =========================================================================== +// Replacement for TORCH_CHECK() that works with stable ABI. +// Uses STD_TORCH_CHECK from the stable ABI headers. +// Note: Unlike TORCH_CHECK, this always requires a message argument. + +#define VISION_CHECK(cond, ...) STD_TORCH_CHECK(cond, __VA_ARGS__) + +// =========================================================================== +// Type Aliases +// =========================================================================== +// Convenient aliases for stable ABI types in vision namespace + +namespace vision { +namespace stable { + +// Tensor types +using Tensor = torch::stable::Tensor; + +// Device types +using Device = torch::stable::Device; +using DeviceType = torch::headeronly::DeviceType; +using DeviceIndex = torch::stable::accelerator::DeviceIndex; + +// Scalar types (dtype) +using ScalarType = torch::headeronly::ScalarType; + +// DeviceGuard for CUDA context management +using DeviceGuard = torch::stable::accelerator::DeviceGuard; + +// Array reference type for sizes/strides +using IntArrayRef = torch::headeronly::IntHeaderOnlyArrayRef; + +// Layout and MemoryFormat +using Layout = torch::headeronly::Layout; +using MemoryFormat = torch::headeronly::MemoryFormat; + +// =========================================================================== +// Constants +// =========================================================================== + +// Device type constants +constexpr auto kCPU = torch::headeronly::DeviceType::CPU; +constexpr auto kCUDA = torch::headeronly::DeviceType::CUDA; + +// Scalar type constants (equivalents of at::kUInt8, at::kFloat32, etc.) 
+constexpr auto kByte = torch::headeronly::ScalarType::Byte; +constexpr auto kChar = torch::headeronly::ScalarType::Char; +constexpr auto kShort = torch::headeronly::ScalarType::Short; +constexpr auto kInt = torch::headeronly::ScalarType::Int; +constexpr auto kLong = torch::headeronly::ScalarType::Long; +constexpr auto kHalf = torch::headeronly::ScalarType::Half; +constexpr auto kFloat = torch::headeronly::ScalarType::Float; +constexpr auto kDouble = torch::headeronly::ScalarType::Double; +constexpr auto kBool = torch::headeronly::ScalarType::Bool; +constexpr auto kUInt16 = torch::headeronly::ScalarType::UInt16; + +// Layout constants +constexpr auto kStrided = torch::headeronly::Layout::Strided; + +// =========================================================================== +// Helper Functions - Tensor Creation +// =========================================================================== + +// Stable version of at::empty() +inline Tensor empty( + std::initializer_list sizes, + ScalarType dtype, + Device device) { + std::vector sizesVec(sizes); + return torch::stable::empty( + IntArrayRef(sizesVec.data(), sizesVec.size()), + dtype, + kStrided, + device); +} + +// Overload taking a vector +inline Tensor empty( + const std::vector& sizes, + ScalarType dtype, + Device device) { + return torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + device); +} + +// Helper to create CPU tensors +inline Tensor emptyCPU( + std::initializer_list sizes, + ScalarType dtype) { + return empty(sizes, dtype, Device(kCPU)); +} + +// Stable version of at::zeros() - creates via empty then zeros +inline Tensor zeros( + std::initializer_list sizes, + ScalarType dtype, + Device device) { + std::vector sizesVec(sizes); + auto tensor = torch::stable::empty( + IntArrayRef(sizesVec.data(), sizesVec.size()), + dtype, + kStrided, + device); + // Use dispatcher to call aten::zero_ + std::array stack{torch::stable::detail::from(tensor)}; + 
TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// =========================================================================== +// Helper Functions - Tensor Operations +// =========================================================================== + +// Stable version of tensor.copy_(src) +inline void copy_(Tensor& dst, const Tensor& src) { + torch::stable::copy_(dst, src); +} + +// Stable version of tensor.to(device) +inline Tensor to(const Tensor& tensor, const Device& device) { + return torch::stable::to(tensor, device); +} + +// Stable version of tensor.narrow(dim, start, length) +inline Tensor narrow(Tensor tensor, int64_t dim, int64_t start, int64_t length) { + return torch::stable::narrow(tensor, dim, start, length); +} + +// Note: contiguous() is provided by torch::stable::contiguous() directly +// Do NOT define a vision::stable::contiguous wrapper as it conflicts with the +// default parameter in torch::stable::contiguous(tensor, memory_format = Contiguous) + +// Stable version of tensor.select(dim, index) - from torch::stable::select +inline Tensor select(const Tensor& tensor, int64_t dim, int64_t index) { + return torch::stable::select(tensor, dim, index); +} + +// Helper for tensor.is_contiguous() +inline bool is_contiguous(const Tensor& tensor) { + return tensor.is_contiguous(); +} + +// =========================================================================== +// Helper Functions - Dispatcher Wrappers +// =========================================================================== + +// Stable version of tensor.sort() - returns (values, indices) +// Uses dispatcher to call aten::sort.stable +inline std::pair sort( + const Tensor& tensor, + int64_t dim, + bool descending) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(true), // stable sort + torch::stable::detail::from(dim), + 
torch::stable::detail::from(descending)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::sort", "stable", stack.data(), TORCH_ABI_VERSION)); + return std::make_pair( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1])); +} + +// Stable version of argsort - returns indices only +inline Tensor argsort(const Tensor& tensor, int64_t dim, bool descending) { + auto [values, indices] = sort(tensor, dim, descending); + return indices; +} + +// Stable version of tensor.permute() +inline Tensor permute( + const Tensor& tensor, + std::initializer_list dims) { + std::vector dimsVec(dims); + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from( + IntArrayRef(dimsVec.data(), dimsVec.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::permute", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::cat() - concatenates tensors along a dimension +inline Tensor cat(const std::vector& tensors, int64_t dim = 0) { + std::array stack{ + torch::stable::detail::from(tensors), + torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::cat", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::clamp() +inline Tensor clamp( + const Tensor& tensor, + double min_val, + double max_val) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(min_val), + torch::stable::detail::from(max_val)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::clamp", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::floor() +inline Tensor floor(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::floor", "", stack.data(), TORCH_ABI_VERSION)); + return 
torch::stable::detail::to(stack[0]); +} + +// Stable version of at::ceil() +inline Tensor ceil(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::ceil", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.reshape() +inline Tensor reshape(const Tensor& tensor, const std::vector& shape) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(shape.data(), shape.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::reshape", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.view() +inline Tensor view(const Tensor& tensor, const std::vector& shape) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(shape.data(), shape.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::view", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.flatten(start_dim) +inline Tensor flatten(const Tensor& tensor, int64_t start_dim) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(start_dim), + torch::stable::detail::from(static_cast(-1))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Note: transpose() is provided by torch::stable::transpose() + +// Stable version of tensor.zero_() +inline Tensor& zero_(Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + tensor = torch::stable::detail::to(stack[0]); + return tensor; +} + +// Stable version of tensor.addmm_(mat1, mat2) +inline Tensor& 
addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2) { + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(mat1), + torch::stable::detail::from(mat2), + torch::stable::detail::from(1.0), // beta + torch::stable::detail::from(1.0)}; // alpha + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::addmm_", "", stack.data(), TORCH_ABI_VERSION)); + self = torch::stable::detail::to(stack[0]); + return self; +} + +// Stable version of at::zeros with vector sizes +inline Tensor zeros( + const std::vector& sizes, + ScalarType dtype, + Device device) { + auto tensor = torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + device); + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::zeros_like() +inline Tensor zeros_like(const Tensor& tensor) { + // Use dispatcher to call aten::zeros_like + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::zeros_like", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of at::ones_like() * val +inline Tensor ones_like_times(const Tensor& tensor, const Tensor& val) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(val)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::mul", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.sum(dims) +inline Tensor sum(const Tensor& tensor, const std::vector& dims) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(dims.data(), dims.size())), + torch::stable::detail::from(false)}; // keepdim + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::sum", 
"dim_IntList", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor + other (broadcasting add) +inline Tensor add(const Tensor& self, const Tensor& other) { + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(other), + torch::stable::detail::from(1.0)}; // alpha + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::add", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.index_select(dim, index) +inline Tensor index_select(const Tensor& tensor, int64_t dim, const Tensor& index) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(dim), + torch::stable::detail::from(index)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::index_select", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.masked_select(mask) +inline Tensor masked_select(const Tensor& tensor, const Tensor& mask) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(mask)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::masked_select", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.flip(dims) +inline Tensor flip(const Tensor& tensor, std::initializer_list dims) { + std::vector dimsVec(dims); + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(IntArrayRef(dimsVec.data(), dimsVec.size()))}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::flip", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of torch.from_file() - read file into tensor +inline Tensor from_file( + const std::string& filename, + bool shared, + int64_t size, + ScalarType dtype) { + std::array stack{ + 
torch::stable::detail::from(filename), + torch::stable::detail::from(shared), + torch::stable::detail::from(size), + torch::stable::detail::from(dtype)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::from_file", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Note: squeeze() is provided by torch::stable::squeeze() + +// Stable version of tensor.unsqueeze(dim) +inline Tensor unsqueeze(const Tensor& tensor, int64_t dim) { + std::array stack{ + torch::stable::detail::from(tensor), + torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Stable version of tensor.clone() +inline Tensor clone(const Tensor& tensor) { + std::array stack{torch::stable::detail::from(tensor)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::clone", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// Create empty tensor with ChannelsLast memory format +inline Tensor emptyCPUChannelsLast(const std::vector& sizes, ScalarType dtype) { + std::array stack{ + torch::stable::detail::from(IntArrayRef(sizes.data(), sizes.size())), + torch::stable::detail::from(dtype), + torch::stable::detail::from(kStrided), + torch::stable::detail::from(Device(kCPU)), + torch::stable::detail::from(MemoryFormat::ChannelsLast)}; + TORCH_ERROR_CODE_CHECK( + torch_call_dispatcher("aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +// =========================================================================== +// Helper Functions - Utility +// =========================================================================== + +// Helper to get a human-readable name for a scalar type +inline const char* scalarTypeName(ScalarType dtype) { + switch (dtype) { + case ScalarType::Byte: + return "uint8"; + case 
ScalarType::Char: + return "int8"; + case ScalarType::Short: + return "int16"; + case ScalarType::Int: + return "int32"; + case ScalarType::Long: + return "int64"; + case ScalarType::Half: + return "float16"; + case ScalarType::Float: + return "float32"; + case ScalarType::Double: + return "float64"; + case ScalarType::Bool: + return "bool"; + default: + return "unknown"; + } +} + +// Helper to get a human-readable name for a device type +inline const char* deviceTypeName(DeviceType dtype) { + switch (dtype) { + case DeviceType::CPU: + return "cpu"; + case DeviceType::CUDA: + return "cuda"; + default: + return "unknown"; + } +} + +// Helper to convert IntArrayRef to a string for error messages +inline std::string intArrayRefToString(const IntArrayRef& arr) { + std::string result = "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) { + result += ", "; + } + result += std::to_string(arr[i]); + } + result += "]"; + return result; +} + +} // namespace stable + +// =========================================================================== +// CUDA Accumulator Type Traits +// =========================================================================== +// Replacement for at::acc_type that works without ATen headers +// Used in CUDA device code for accumulating sums with higher precision + +#ifdef __CUDACC__ +namespace cuda { + +// Primary template - default accumulator type is the same as input +template +struct acc_type { + using type = T; +}; + +// Specializations for CUDA: Half uses float for accumulation +template <> +struct acc_type<__half, true> { + using type = float; +}; + +// Float and double use themselves +template <> +struct acc_type { + using type = float; +}; + +template <> +struct acc_type { + using type = double; +}; + +// Helper alias for convenience +template +using acc_type_t = typename acc_type::type; + +} // namespace cuda +#endif // __CUDACC__ + +} // namespace vision diff --git a/torchvision/csrc/io/image/common.cpp 
b/torchvision/csrc/io/image/common.cpp index 7743961a09d..4f8d3c42b2a 100644 --- a/torchvision/csrc/io/image/common.cpp +++ b/torchvision/csrc/io/image/common.cpp @@ -12,13 +12,16 @@ void* PyInit_image(void) { namespace vision { namespace image { -void validate_encoded_data(const torch::Tensor& encoded_data) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, +using namespace vision::stable; + +void validate_encoded_data(const Tensor& encoded_data) { + VISION_CHECK( + encoded_data.is_contiguous(), "Input tensor must be contiguous."); + VISION_CHECK( + encoded_data.scalar_type() == kByte, "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( + scalarTypeName(encoded_data.scalar_type())); + VISION_CHECK( encoded_data.dim() == 1 && encoded_data.numel() > 0, "Input tensor must be 1-dimensional and non-empty, got ", encoded_data.dim(), diff --git a/torchvision/csrc/io/image/common.h b/torchvision/csrc/io/image/common.h index d81acfda7d4..84645c9452a 100644 --- a/torchvision/csrc/io/image/common.h +++ b/torchvision/csrc/io/image/common.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include "../../StableABICompat.h" namespace vision { namespace image { @@ -14,7 +14,7 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; const ImageReadMode IMAGE_READ_MODE_RGB = 3; const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; -void validate_encoded_data(const torch::Tensor& encoded_data); +void validate_encoded_data(const vision::stable::Tensor& encoded_data); bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( ImageReadMode mode, diff --git a/torchvision/csrc/io/image/cpu/decode_gif.cpp b/torchvision/csrc/io/image/cpu/decode_gif.cpp index 93b0861c5da..67fa3119ce3 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.cpp +++ b/torchvision/csrc/io/image/cpu/decode_gif.cpp @@ -6,6 +6,8 @@ namespace vision { 
namespace image { +using namespace vision::stable; + typedef struct reader_helper_t { uint8_t const* encoded_data; // input tensor data pointer size_t encoded_data_size; // size of input tensor in bytes @@ -30,7 +32,7 @@ int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) { return num_bytes_to_read; } -torch::Tensor decode_gif(const torch::Tensor& encoded_data) { +Tensor decode_gif(const Tensor& encoded_data) { // LibGif docs: https://giflib.sourceforge.net/intro.html // Refer over there for more details on the libgif API, API ref, and a // detailed description of the GIF format. @@ -57,13 +59,13 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // If we do that, we'd have to make sure the buffers are never written to by // GIFLIB, otherwise we'd be overridding the tensor data. reader_helper_t reader_helper; - reader_helper.encoded_data = encoded_data.data_ptr(); + reader_helper.encoded_data = encoded_data.const_data_ptr(); reader_helper.encoded_data_size = encoded_data.numel(); reader_helper.num_bytes_read = 0; GifFileType* gifFile = DGifOpen(static_cast(&reader_helper), read_from_tensor, &error); - TORCH_CHECK( + VISION_CHECK( (gifFile != nullptr) && (error == D_GIF_SUCCEEDED), "DGifOpenFileName() failed - ", error); @@ -71,12 +73,12 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { if (DGifSlurp(gifFile) == GIF_ERROR) { auto gifFileError = gifFile->Error; DGifCloseFile(gifFile, &error); - TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError); + VISION_CHECK(false, "DGifSlurp() failed - ", gifFileError); } auto num_images = gifFile->ImageCount; // This check should already done within DGifSlurp(), just to be safe - TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!"); + VISION_CHECK(num_images > 0, "GIF file should contain at least one image!"); GifColorType bg = {0, 0, 0}; if (gifFile->SColorMap) { @@ -94,12 +96,19 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { // We 
output a channels-last tensor for consistency with other image decoders. // Torchvision's resize tends to be is faster on uint8 channels-last tensors. - auto options = torch::TensorOptions() - .dtype(torch::kU8) - .memory_format(torch::MemoryFormat::ChannelsLast); - auto out = torch::empty( - {int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options); - auto out_a = out.accessor(); + std::vector sizes = { + int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}; + auto out = emptyCPUChannelsLast(sizes, kByte); + auto out_ptr = out.mutable_data_ptr(); + + // Calculate strides for NCHW layout with ChannelsLast memory format + // In ChannelsLast format for NCHW tensor, memory is laid out as NHWC + // Stride order: N -> HWC, C -> 1, H -> WC, W -> C + int64_t stride_n = 3 * out_h * out_w; + int64_t stride_c = 1; + int64_t stride_h = out_w * 3; + int64_t stride_w = 3; + for (int i = 0; i < num_images; i++) { const SavedImage& img = gifFile->SavedImages[i]; @@ -109,7 +118,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { const GifImageDesc& desc = img.ImageDesc; const ColorMapObject* cmap = desc.ColorMap ? desc.ColorMap : gifFile->SColorMap; - TORCH_CHECK( + VISION_CHECK( cmap != nullptr, "Global and local color maps are missing. This should never happen!"); @@ -132,14 +141,18 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { (gcb.DisposalMode == DISPOSAL_UNSPECIFIED || gcb.DisposalMode == DISPOSE_DO_NOT || gcb.DisposalMode == DISPOSE_PREVIOUS)) { - out[i] = out[i - 1]; + // Copy previous frame to current frame + auto prev_frame_ptr = out_ptr + (i - 1) * stride_n; + auto curr_frame_ptr = out_ptr + i * stride_n; + std::memcpy(curr_frame_ptr, prev_frame_ptr, stride_n); } else { // Background. 
If bg wasn't defined, it will be (0, 0, 0) for (int h = 0; h < gifFile->SHeight; h++) { for (int w = 0; w < gifFile->SWidth; w++) { - out_a[i][0][h][w] = bg.Red; - out_a[i][1][h][w] = bg.Green; - out_a[i][2][h][w] = bg.Blue; + auto base_idx = i * stride_n + h * stride_h + w * stride_w; + out_ptr[base_idx + 0 * stride_c] = bg.Red; + out_ptr[base_idx + 1 * stride_c] = bg.Green; + out_ptr[base_idx + 2 * stride_c] = bg.Blue; } } } @@ -151,17 +164,20 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) { continue; } GifColorType rgb = cmap->Colors[c]; - out_a[i][0][h + desc.Top][w + desc.Left] = rgb.Red; - out_a[i][1][h + desc.Top][w + desc.Left] = rgb.Green; - out_a[i][2][h + desc.Top][w + desc.Left] = rgb.Blue; + auto base_idx = i * stride_n + (h + desc.Top) * stride_h + + (w + desc.Left) * stride_w; + out_ptr[base_idx + 0 * stride_c] = rgb.Red; + out_ptr[base_idx + 1 * stride_c] = rgb.Green; + out_ptr[base_idx + 2 * stride_c] = rgb.Blue; } } } - out = out.squeeze(0); // remove batch dim if there's only one image + out = torch::stable::squeeze( + out, 0); // remove batch dim if there's only one image DGifCloseFile(gifFile, &error); - TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); + VISION_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); return out; } diff --git a/torchvision/csrc/io/image/cpu/decode_gif.h b/torchvision/csrc/io/image/cpu/decode_gif.h index 68d5073c91b..2b14112a219 100644 --- a/torchvision/csrc/io/image/cpu/decode_gif.h +++ b/torchvision/csrc/io/image/cpu/decode_gif.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { // encoded_data tensor must be 1D uint8 and contiguous -C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data); +vision::stable::Tensor decode_gif(const vision::stable::Tensor& encoded_data); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp 
b/torchvision/csrc/io/image/cpu/decode_image.cpp index 43a688604f6..aa536fef59e 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -1,4 +1,5 @@ #include "decode_image.h" +#include "../common.h" #include "decode_gif.h" #include "decode_jpeg.h" @@ -8,32 +9,34 @@ namespace vision { namespace image { -torch::Tensor decode_image( - const torch::Tensor& data, +using namespace vision::stable; + +Tensor decode_image( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { // Check that tensor is a CPU tensor - TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); + VISION_CHECK(data.device() == Device(kCPU), "Expected a CPU tensor"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + VISION_CHECK(data.scalar_type() == kByte, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional - TORCH_CHECK( + VISION_CHECK( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); auto err_msg = "Unsupported image file. Only jpeg, png, webp and gif are currently supported. 
For avif and heic format, please rely on `decode_avif` and `decode_heic` directly."; - auto datap = data.data_ptr(); + auto datap = data.const_data_ptr(); const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" - TORCH_CHECK(data.numel() >= 3, err_msg); + VISION_CHECK(data.numel() >= 3, err_msg); if (memcmp(jpeg_signature, datap, 3) == 0) { return decode_jpeg(data, mode, apply_exif_orientation); } const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" - TORCH_CHECK(data.numel() >= 4, err_msg); + VISION_CHECK(data.numel() >= 4, err_msg); if (memcmp(png_signature, datap, 4) == 0) { return decode_png(data, mode, apply_exif_orientation); } @@ -42,7 +45,7 @@ torch::Tensor decode_image( 0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a" const uint8_t gif_signature_2[6] = { 0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a" - TORCH_CHECK(data.numel() >= 6, err_msg); + VISION_CHECK(data.numel() >= 6, err_msg); if (memcmp(gif_signature_1, datap, 6) == 0 || memcmp(gif_signature_2, datap, 6) == 0) { return decode_gif(data); @@ -51,13 +54,13 @@ torch::Tensor decode_image( const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" - TORCH_CHECK(data.numel() >= 15, err_msg); + VISION_CHECK(data.numel() >= 15, err_msg); if ((memcmp(webp_signature_begin, datap, 4) == 0) && (memcmp(webp_signature_end, datap + 8, 7) == 0)) { return decode_webp(data, mode); } - TORCH_CHECK(false, err_msg); + VISION_CHECK(false, err_msg); } } // namespace image diff --git a/torchvision/csrc/io/image/cpu/decode_image.h b/torchvision/csrc/io/image/cpu/decode_image.h index f66d47eccd4..e20e6b355cf 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.h +++ b/torchvision/csrc/io/image/cpu/decode_image.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT 
torch::Tensor decode_image( - const torch::Tensor& data, +vision::stable::Tensor decode_image( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index 052b98e1be9..3e7c739b716 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -6,12 +6,14 @@ namespace vision { namespace image { +using namespace vision::stable; + #if !JPEG_FOUND -torch::Tensor decode_jpeg( - const torch::Tensor& data, +Tensor decode_jpeg( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - TORCH_CHECK( + VISION_CHECK( false, "decode_jpeg: torchvision not compiled with libjpeg support"); } #else @@ -129,19 +131,16 @@ void convert_line_cmyk_to_gray( } // namespace -torch::Tensor decode_jpeg( - const torch::Tensor& data, +Tensor decode_jpeg( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - validate_encoded_data(data); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; - auto datap = data.data_ptr(); + auto datap = data.const_data_ptr(); // Setup decompression structure cinfo.err = jpeg_std_error(&jerr.pub); jerr.pub.error_exit = torch_jpeg_error_exit; @@ -151,7 +150,7 @@ torch::Tensor decode_jpeg( * We need to clean up the JPEG object. 
*/ jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, jerr.jpegLastErrorMsg); + VISION_CHECK(false, jerr.jpegLastErrorMsg); } jpeg_create_decompress(&cinfo); @@ -192,7 +191,8 @@ torch::Tensor decode_jpeg( */ default: jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); + VISION_CHECK( + false, "The provided mode is not supported for JPEG files"); } jpeg_calc_output_dimensions(&cinfo); @@ -209,12 +209,11 @@ torch::Tensor decode_jpeg( int width = cinfo.output_width; int stride = width * channels; - auto tensor = - torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.data_ptr(); - torch::Tensor cmyk_line_tensor; + auto tensor = emptyCPU({int64_t(height), int64_t(width), channels}, kByte); + auto ptr = tensor.mutable_data_ptr(); + Tensor cmyk_line_tensor; if (cmyk_to_rgb_or_gray) { - cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + cmyk_line_tensor = emptyCPU({int64_t(width), 4}, kByte); } while (cinfo.output_scanline < cinfo.output_height) { @@ -223,7 +222,7 @@ torch::Tensor decode_jpeg( * more than one scanline at a time if that's more convenient. 
*/ if (cmyk_to_rgb_or_gray) { - auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + auto cmyk_line_ptr = cmyk_line_tensor.mutable_data_ptr(); jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); if (channels == 3) { @@ -239,7 +238,7 @@ torch::Tensor decode_jpeg( jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); - auto output = tensor.permute({2, 0, 1}); + auto output = permute(tensor, {2, 0, 1}); if (apply_exif_orientation) { return exif_orientation_transform(output, exif_orientation); diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.h b/torchvision/csrc/io/image/cpu/decode_jpeg.h index 7412a46d2ea..ed8760a5862 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.h @@ -1,18 +1,18 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_jpeg( - const torch::Tensor& data, +vision::stable::Tensor decode_jpeg( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); -C10_EXPORT int64_t _jpeg_version(); -C10_EXPORT bool _is_compiled_against_turbo(); +int64_t _jpeg_version(); +bool _is_compiled_against_turbo(); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 5ea6f073975..4db0eb36b75 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -6,14 +6,15 @@ namespace vision { namespace image { +using namespace vision::stable; using namespace exif_private; #if !PNG_FOUND -torch::Tensor decode_png( - const torch::Tensor& data, +Tensor decode_png( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - TORCH_CHECK( + VISION_CHECK( false, "decode_png: torchvision not compiled with libPNG support"); } #else @@ -23,35 +24,32 @@ bool is_little_endian() { return 
*(uint8_t*)&x; } -torch::Tensor decode_png( - const torch::Tensor& data, +Tensor decode_png( + const Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - validate_encoded_data(data); auto png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); - TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + VISION_CHECK(png_ptr, "libpng read structure allocation failed!") auto info_ptr = png_create_info_struct(png_ptr); if (!info_ptr) { png_destroy_read_struct(&png_ptr, nullptr, nullptr); // Seems redundant with the if statement. done here to avoid leaking memory. - TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + VISION_CHECK(info_ptr, "libpng info structure allocation failed!") } - auto accessor = data.accessor(); - auto datap = accessor.data(); - auto datap_len = accessor.size(0); + auto datap = data.const_data_ptr(); + auto datap_len = data.numel(); if (setjmp(png_jmpbuf(png_ptr)) != 0) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "Internal error."); + VISION_CHECK(false, "Internal error."); } - TORCH_CHECK(datap_len >= 8, "Content is too small for png!") + VISION_CHECK(datap_len >= 8, "Content is too small for png!") auto is_png = !png_sig_cmp(datap, 0, 8); - TORCH_CHECK(is_png, "Content is not png!") + VISION_CHECK(is_png, "Content is not png!") struct Reader { png_const_bytep ptr; @@ -64,7 +62,7 @@ torch::Tensor decode_png( png_bytep output, png_size_t bytes) { auto reader = static_cast(png_get_io_ptr(png_ptr)); - TORCH_CHECK( + VISION_CHECK( reader->count >= bytes, "Out of bound read in decode_png. 
Probably, the input image is corrupted"); std::copy(reader->ptr, reader->ptr + bytes, output); @@ -91,12 +89,12 @@ torch::Tensor decode_png( if (retval != 1) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(retval == 1, "Could read image metadata from content.") + VISION_CHECK(retval == 1, "Could read image metadata from content.") } if (bit_depth > 8 && bit_depth != 16) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK( + VISION_CHECK( false, "bit depth of png image is " + std::to_string(bit_depth) + ". Only <=8 and 16 are supported.") @@ -188,7 +186,7 @@ torch::Tensor decode_png( break; default: png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "The provided mode is not supported for PNG files"); + VISION_CHECK(false, "The provided mode is not supported for PNG files"); } png_read_update_info(png_ptr, info_ptr); @@ -196,19 +194,19 @@ torch::Tensor decode_png( auto num_pixels_per_row = width * channels; auto is_16_bits = bit_depth == 16; - auto tensor = torch::empty( + auto tensor = emptyCPU( {int64_t(height), int64_t(width), channels}, - is_16_bits ? at::kUInt16 : torch::kU8); + is_16_bits ? kUInt16 : kByte); if (is_little_endian()) { png_set_swap(png_ptr); } - auto t_ptr = (uint8_t*)tensor.data_ptr(); + auto t_ptr = (uint8_t*)tensor.mutable_data_ptr(); for (int pass = 0; pass < number_of_passes; pass++) { for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, t_ptr, nullptr); t_ptr += num_pixels_per_row * (is_16_bits ? 
2 : 1); } - t_ptr = (uint8_t*)tensor.data_ptr(); + t_ptr = (uint8_t*)tensor.mutable_data_ptr(); } int exif_orientation = -1; @@ -218,7 +216,7 @@ torch::Tensor decode_png( png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - auto output = tensor.permute({2, 0, 1}); + auto output = permute(tensor, {2, 0, 1}); if (apply_exif_orientation) { return exif_orientation_transform(output, exif_orientation); } diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h index faaffa7ae49..286cdf7571b 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.h +++ b/torchvision/csrc/io/image/cpu/decode_png.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_png( - const torch::Tensor& data, +vision::stable::Tensor decode_png( + const vision::stable::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, bool apply_exif_orientation = false); diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index 80fe68862fb..945fffeba70 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -9,35 +9,30 @@ namespace vision { namespace image { +using namespace vision::stable; + #if !WEBP_FOUND -torch::Tensor decode_webp( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - TORCH_CHECK( +Tensor decode_webp(const Tensor& encoded_data, ImageReadMode mode) { + VISION_CHECK( false, "decode_webp: torchvision not compiled with libwebp support"); } #else -torch::Tensor decode_webp( - const torch::Tensor& encoded_data, - ImageReadMode mode) { +Tensor decode_webp(const Tensor& encoded_data, ImageReadMode mode) { validate_encoded_data(encoded_data); - auto encoded_data_p = encoded_data.data_ptr(); + auto encoded_data_p = encoded_data.const_data_ptr(); auto encoded_data_size = encoded_data.numel(); 
WebPBitstreamFeatures features; auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features); - TORCH_CHECK( + VISION_CHECK( res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res); - TORCH_CHECK( + VISION_CHECK( !features.has_animation, "Animated webp files are not supported."); - if (mode == IMAGE_READ_MODE_GRAY || mode == IMAGE_READ_MODE_GRAY_ALPHA) { - TORCH_WARN_ONCE( - "Webp does not support grayscale conversions. " - "The returned tensor will be in the colorspace of the original image."); - } + // Note: TORCH_WARN_ONCE is not available in stable ABI, so we just skip the + // warning auto return_rgb = should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( @@ -52,13 +47,18 @@ torch::Tensor decode_webp( auto decoded_data = decoding_func(encoded_data_p, encoded_data_size, &width, &height); - TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + VISION_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed."); + + // Create tensor and copy data (from_blob with deleter not available in stable + // ABI) + auto out = emptyCPU({height, width, num_channels}, kByte); + auto out_ptr = out.mutable_data_ptr(); + std::memcpy(out_ptr, decoded_data, height * width * num_channels); - auto deleter = [decoded_data](void*) { WebPFree(decoded_data); }; - auto out = torch::from_blob( - decoded_data, {height, width, num_channels}, deleter, torch::kUInt8); + // Free the webp-allocated memory + WebPFree(decoded_data); - return out.permute({2, 0, 1}); + return permute(out, {2, 0, 1}); } #endif // WEBP_FOUND diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h index d5c81547c42..8e897ff8d8b 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.h +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" namespace vision { namespace image { -C10_EXPORT 
torch::Tensor decode_webp( - const torch::Tensor& encoded_data, +vision::stable::Tensor decode_webp( + const vision::stable::Tensor& encoded_data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp index d2ed73071a2..71b1749f5b2 100644 --- a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp @@ -1,14 +1,17 @@ #include "encode_jpeg.h" +#include "../common.h" #include "common_jpeg.h" namespace vision { namespace image { +using namespace vision::stable; + #if !JPEG_FOUND -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - TORCH_CHECK( +Tensor encode_jpeg(const Tensor& data, int64_t quality) { + VISION_CHECK( false, "encode_jpeg: torchvision not compiled with libjpeg support"); } @@ -24,10 +27,9 @@ using JpegSizeType = size_t; using namespace detail; -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); - // Define compression structures and error handling +Tensor encode_jpeg(const Tensor& data, int64_t quality) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI, so we just + // skip the logging. Define compression structures and error handling. struct jpeg_compress_struct cinfo {}; struct torch_jpeg_error_mgr jerr {}; @@ -48,25 +50,26 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { free(jpegBuf); } - TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); + VISION_CHECK(false, (const char*)jerr.jpegLastErrorMsg); } // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device() == Device(kCPU), "Input tensor should be on CPU"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be
uint8"); + VISION_CHECK( + data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + VISION_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); // Get image info int channels = data.size(0); int height = data.size(1); int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); + auto input = torch::stable::contiguous(permute(data, {1, 2, 0})); - TORCH_CHECK( + VISION_CHECK( channels == 1 || channels == 3, "The number of channels should be 1 or 3, got: ", channels); @@ -90,21 +93,24 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { jpeg_start_compress(&cinfo, TRUE); auto stride = width * channels; - auto ptr = input.data_ptr(); + auto ptr = input.const_data_ptr(); // Encode JPEG file while (cinfo.next_scanline < cinfo.image_height) { - jpeg_write_scanlines(&cinfo, &ptr, 1); + jpeg_write_scanlines(&cinfo, const_cast(&ptr), 1); ptr += stride; } jpeg_finish_compress(&cinfo); jpeg_destroy_compress(&cinfo); - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto out_tensor = - torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); - jpegBuf = nullptr; + // Create tensor and copy data (from_blob with deleter not available in stable + // ABI) + auto out_tensor = emptyCPU({(int64_t)jpegSize}, kByte); + auto out_ptr = out_tensor.mutable_data_ptr(); + std::memcpy(out_ptr, jpegBuf, jpegSize); + free(jpegBuf); + return out_tensor; } #endif diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.h b/torchvision/csrc/io/image/cpu/encode_jpeg.h index 25084e154d6..2b69cc4e319 100644 --- a/torchvision/csrc/io/image/cpu/encode_jpeg.h +++ b/torchvision/csrc/io/image/cpu/encode_jpeg.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor encode_jpeg( - const torch::Tensor& 
data, +vision::stable::Tensor encode_jpeg( + const vision::stable::Tensor& data, int64_t quality); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/encode_png.cpp b/torchvision/csrc/io/image/cpu/encode_png.cpp index d55a0ed3ff6..8a9671b5b7b 100644 --- a/torchvision/csrc/io/image/cpu/encode_png.cpp +++ b/torchvision/csrc/io/image/cpu/encode_png.cpp @@ -1,14 +1,17 @@ -#include "encode_jpeg.h" +#include "encode_png.h" +#include "../common.h" #include "common_png.h" namespace vision { namespace image { +using namespace vision::stable; + #if !PNG_FOUND -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - TORCH_CHECK( +Tensor encode_png(const Tensor& data, int64_t compression_level) { + VISION_CHECK( false, "encode_png: torchvision not compiled with libpng support"); } @@ -64,9 +67,9 @@ void torch_png_write_data( } // namespace -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); - // Define compression structures and error handling +Tensor encode_png(const Tensor& data, int64_t compression_level) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI, so we just + // skip the logging. Define compression structures and error handling. png_structp png_write; png_infop info_ptr; struct torch_png_error_mgr err_ptr; @@ -93,30 +96,31 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { free(buf_info.buffer); } - TORCH_CHECK(false, err_ptr.pngLastErrorMsg); + VISION_CHECK(false, err_ptr.pngLastErrorMsg); } // Check that the compression level is between 0 and 9 - TORCH_CHECK( + VISION_CHECK( compression_level >= 0 && compression_level <= 9, "Compression level should be between 0 and 9"); // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device() == Device(kCPU), "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + VISION_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); // Get image info int channels = data.size(0); int height = data.size(1); int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); + auto input = torch::stable::contiguous(permute(data, {1, 2, 0})); - TORCH_CHECK( + VISION_CHECK( channels == 1 || channels == 3, "The number of channels should be 1 or 3, got: ", channels); @@ -150,7 +154,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { png_write_info(png_write, info_ptr); auto stride = width * channels; - auto ptr = input.data_ptr(); + auto ptr = input.const_data_ptr(); // Encode PNG file for (int y = 0; y < height; ++y) { @@ -164,13 +168,10 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { // Destroy structures png_destroy_write_struct(&png_write, &info_ptr); - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto outTensor = torch::empty({(long)buf_info.size}, options); - - // Copy memory from png buffer, since torch cannot get ownership of it via - // `from_blob` - auto outPtr = outTensor.data_ptr(); - std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); + // Create tensor and copy data + auto outTensor = emptyCPU({(int64_t)buf_info.size}, kByte); + auto outPtr = outTensor.mutable_data_ptr(); + std::memcpy(outPtr, buf_info.buffer, buf_info.size); free(buf_info.buffer); return outTensor; diff --git a/torchvision/csrc/io/image/cpu/encode_png.h b/torchvision/csrc/io/image/cpu/encode_png.h index 86a67c8706e..a27b4d47b5a 100644 --- a/torchvision/csrc/io/image/cpu/encode_png.h 
+++ b/torchvision/csrc/io/image/cpu/encode_png.h @@ -1,12 +1,12 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor encode_png( - const torch::Tensor& data, +vision::stable::Tensor encode_png( + const vision::stable::Tensor& data, int64_t compression_level); } // namespace image diff --git a/torchvision/csrc/io/image/cpu/exif.h b/torchvision/csrc/io/image/cpu/exif.h index 7680737f8c0..022340288fc 100644 --- a/torchvision/csrc/io/image/cpu/exif.h +++ b/torchvision/csrc/io/image/cpu/exif.h @@ -58,12 +58,14 @@ direct, #include #endif -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { namespace exif_private { +using namespace vision::stable; + constexpr uint16_t APP1 = 0xe1; constexpr uint16_t ENDIANNESS_INTEL = 0x49; constexpr uint16_t ENDIANNESS_MOTO = 0x4d; @@ -78,7 +80,7 @@ class ExifDataReader { return _size; } const unsigned char& operator[](size_t index) const { - TORCH_CHECK(index >= 0 && index < _size); + VISION_CHECK(index >= 0 && index < _size, "EXIF data index out of bounds"); return _ptr[index]; } @@ -227,27 +229,25 @@ constexpr uint16_t IMAGE_ORIENTATION_RB = 7; // mirrored horizontal & rotate 90 CW constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation -inline torch::Tensor exif_orientation_transform( - const torch::Tensor& image, - int orientation) { +inline Tensor exif_orientation_transform(const Tensor& image, int orientation) { if (orientation == IMAGE_ORIENTATION_TL) { return image; } else if (orientation == IMAGE_ORIENTATION_TR) { - return image.flip(-1); + return flip(image, {-1}); } else if (orientation == IMAGE_ORIENTATION_BR) { // needs 180 rotation equivalent to // flip both horizontally and vertically - return image.flip({-2, -1}); + return flip(image, {-2, -1}); } else if (orientation == IMAGE_ORIENTATION_BL) { - return image.flip(-2); + return flip(image, {-2}); } else if (orientation == IMAGE_ORIENTATION_LT) { - 
return image.transpose(-1, -2); + return torch::stable::transpose(image, -1, -2); } else if (orientation == IMAGE_ORIENTATION_RT) { - return image.transpose(-1, -2).flip(-1); + return flip(torch::stable::transpose(image, -1, -2), {-1}); } else if (orientation == IMAGE_ORIENTATION_RB) { - return image.transpose(-1, -2).flip({-2, -1}); + return flip(torch::stable::transpose(image, -1, -2), {-2, -1}); } else if (orientation == IMAGE_ORIENTATION_LB) { - return image.transpose(-1, -2).flip(-2); + return flip(torch::stable::transpose(image, -1, -2), {-2}); } return image; } diff --git a/torchvision/csrc/io/image/cpu/read_write_file.cpp b/torchvision/csrc/io/image/cpu/read_write_file.cpp index 06de72a5053..9c94304f6f0 100644 --- a/torchvision/csrc/io/image/cpu/read_write_file.cpp +++ b/torchvision/csrc/io/image/cpu/read_write_file.cpp @@ -10,6 +10,8 @@ namespace vision { namespace image { +using namespace vision::stable; + #ifdef _WIN32 namespace { std::wstring utf8_decode(const std::string& str) { @@ -18,7 +20,7 @@ std::wstring utf8_decode(const std::string& str) { } int size_needed = MultiByteToWideChar( CP_UTF8, 0, str.c_str(), static_cast(str.size()), nullptr, 0); - TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + VISION_CHECK(size_needed > 0, "Error converting the content to Unicode"); std::wstring wstrTo(size_needed, 0); MultiByteToWideChar( CP_UTF8, @@ -32,9 +34,7 @@ std::wstring utf8_decode(const std::string& str) { } // namespace #endif -torch::Tensor read_file(const std::string& filename) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.read_file"); +Tensor read_file(std::string filename) { #ifdef _WIN32 // According to // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, @@ -47,49 +47,44 @@ torch::Tensor read_file(const std::string& filename) { int rc = stat(filename.c_str(), &stat_buf); #endif // errno is a variable defined in errno.h - TORCH_CHECK( + 
VISION_CHECK( rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); int64_t size = stat_buf.st_size; - TORCH_CHECK(size > 0, "Expected a non empty file"); + VISION_CHECK(size > 0, "Expected a non empty file"); #ifdef _WIN32 - // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move - // back to use the following implementation since it uses file mapping. - // auto data = - // torch::from_file(filename, /*shared=*/false, /*size=*/size, - // torch::kU8).clone() + // On Windows, read file manually since torch::from_file may not handle UTF-8 FILE* infile = _wfopen(fileW.c_str(), L"rb"); +#else + // Use fopen/fread instead of from_file for stable ABI compatibility + FILE* infile = fopen(filename.c_str(), "rb"); +#endif - TORCH_CHECK(infile != nullptr, "Error opening input file"); + VISION_CHECK(infile != nullptr, "Error opening input file"); - auto data = torch::empty({size}, torch::kU8); - auto dataBytes = data.data_ptr(); + auto data = emptyCPU({size}, kByte); + auto dataBytes = data.mutable_data_ptr(); fread(dataBytes, sizeof(uint8_t), size, infile); fclose(infile); -#else - auto data = - torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); -#endif return data; } -void write_file(const std::string& filename, torch::Tensor& data) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.write_file"); +void write_file(std::string filename, const Tensor& data) { // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + VISION_CHECK(data.device().type() == kCPU, "Input tensor should be on CPU"); // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + data.scalar_type() == kByte, "Input tensor dtype should be uint8"); // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); + 
VISION_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); - auto fileBytes = data.data_ptr(); + auto fileBytes = data.const_data_ptr(); auto fileCStr = filename.c_str(); #ifdef _WIN32 auto fileW = utf8_decode(filename); @@ -98,7 +93,7 @@ void write_file(const std::string& filename, torch::Tensor& data) { FILE* outfile = fopen(fileCStr, "wb"); #endif - TORCH_CHECK(outfile != nullptr, "Error opening output file"); + VISION_CHECK(outfile != nullptr, "Error opening output file"); fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); fclose(outfile); diff --git a/torchvision/csrc/io/image/cpu/read_write_file.h b/torchvision/csrc/io/image/cpu/read_write_file.h index a5a712dd8e2..2170c71bdce 100644 --- a/torchvision/csrc/io/image/cpu/read_write_file.h +++ b/torchvision/csrc/io/image/cpu/read_write_file.h @@ -1,13 +1,13 @@ #pragma once -#include +#include "../../../StableABICompat.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor read_file(const std::string& filename); +vision::stable::Tensor read_file(std::string filename); -C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); +void write_file(std::string filename, const vision::stable::Tensor& data); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 85aa6c760c1..5ef3a5ac996 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -2,11 +2,14 @@ #if !NVJPEG_FOUND namespace vision { namespace image { -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, + +using namespace vision::stable; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device) { - TORCH_CHECK( + Device device) { + VISION_CHECK( false, "decode_jpegs_cuda: torchvision not compiled with nvJPEG support"); } } // namespace 
image @@ -28,40 +31,41 @@ std::vector decode_jpegs_cuda( namespace vision { namespace image { +using namespace vision::stable; + std::mutex decoderMutex; std::unique_ptr cudaJpegDecoder; -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + Device device) { + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI std::lock_guard lock(decoderMutex); - std::vector contig_images; + std::vector contig_images; contig_images.reserve(encoded_images.size()); - TORCH_CHECK( + VISION_CHECK( device.is_cuda(), "Expected the device parameter to be a cuda device"); for (auto& encoded_image : encoded_images) { - TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + VISION_CHECK( + encoded_image.scalar_type() == kByte, "Expected a torch.uint8 tensor"); - TORCH_CHECK( + VISION_CHECK( !encoded_image.is_cuda(), "The input tensor must be on CPU when decoding with nvjpeg") - TORCH_CHECK( + VISION_CHECK( encoded_image.dim() == 1 && encoded_image.numel() > 0, "Expected a non empty 1-dimensional tensor"); // nvjpeg requires images to be contiguous - if (encoded_image.is_contiguous()) { + if (is_contiguous(encoded_image)) { contig_images.push_back(encoded_image); } else { - contig_images.push_back(encoded_image.contiguous()); + contig_images.push_back(torch::stable::contiguous(encoded_image)); } } @@ -72,21 +76,18 @@ std::vector decode_jpegs_cuda( nvjpegStatus_t get_minor_property_status = nvjpegGetProperty(MINOR_VERSION, &minor_version); - TORCH_CHECK( + VISION_CHECK( get_major_property_status == NVJPEG_STATUS_SUCCESS, "nvjpegGetProperty failed: ", get_major_property_status); - TORCH_CHECK( + VISION_CHECK( get_minor_property_status == NVJPEG_STATUS_SUCCESS, "nvjpegGetProperty failed: ", 
get_minor_property_status); - if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { - TORCH_WARN_ONCE( - "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " - "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); - } + // Note: TORCH_WARN_ONCE is not available in stable ABI, so we skip the + // warning about nvjpeg memory leaks in CUDA versions < 11.6 - at::cuda::CUDAGuard device_guard(device); + at::cuda::CUDAGuard device_guard(at::Device(at::kCUDA, device.index())); if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { if (cudaJpegDecoder != nullptr) { @@ -114,7 +115,7 @@ std::vector decode_jpegs_cuda( output_format = NVJPEG_OUTPUT_RGB; break; default: - TORCH_CHECK( + VISION_CHECK( false, "The provided mode is not supported for JPEG decoding on GPU"); } @@ -130,15 +131,15 @@ std::vector decode_jpegs_cuda( return result; } catch (const std::exception& e) { if (typeid(e) != typeid(std::runtime_error)) { - TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + VISION_CHECK(false, "Error while decoding JPEG images: ", e.what()); } else { throw; } } } -CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, c10::cuda::current_device()}, +CUDAJpegDecoder::CUDAJpegDecoder(const Device& target_device) + : original_device{kCUDA, c10::cuda::current_device()}, target_device{target_device}, stream{ target_device.has_index() @@ -160,70 +161,70 @@ CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) NULL, NVJPEG_FLAGS_DEFAULT, &nvjpeg_handle); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize nvjpeg with default backend: ", status); hw_decode_available = false; } else { - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize nvjpeg with hardware backend: ", status); } status = nvjpegJpegStateCreate(nvjpeg_handle, 
&nvjpeg_state); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg state: ", status); status = nvjpegDecoderCreate( nvjpeg_handle, NVJPEG_BACKEND_DEFAULT, &nvjpeg_decoder); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg decoder: ", status); status = nvjpegDecoderStateCreate( nvjpeg_handle, nvjpeg_decoder, &nvjpeg_decoupled_state); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg decoder state: ", status); status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[0]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create pinned buffer: ", status); status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[1]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create pinned buffer: ", status); status = nvjpegBufferDeviceCreate(nvjpeg_handle, NULL, &device_buffer); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create device buffer: ", status); status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[0]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create jpeg stream: ", status); status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[1]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create jpeg stream: ", status); status = nvjpegDecodeParamsCreate(nvjpeg_handle, &nvjpeg_decode_params); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create decode params: ", status); @@ -301,19 +302,16 @@ CUDAJpegDecoder::~CUDAJpegDecoder() { // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); } -std::tuple< - std::vector, - std::vector, - std::vector> +std::tuple, std::vector, std::vector> CUDAJpegDecoder::prepare_buffers( - const std::vector& encoded_images, + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format) { /* This 
function scans the encoded images' jpeg headers and allocates decoding buffers based on the metadata found Args: - - encoded_images (std::vector): a vector of tensors + - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded. Each tensor must have dtype torch.uint8 and device cpu - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y @@ -322,7 +320,7 @@ CUDAJpegDecoder::prepare_buffers( Returns: - decoded_images (std::vector): a vector of nvjpegImages containing pointers to the memory of the decoded images - - output_tensors (std::vector): a vector of Tensors + - output_tensors (std::vector): a vector of Tensors containing the decoded images. `decoded_images` points to the memory of output_tensors - channels (std::vector): a vector of ints containing the number of @@ -335,25 +333,24 @@ CUDAJpegDecoder::prepare_buffers( nvjpegChromaSubsampling_t subsampling; nvjpegStatus_t status; - std::vector output_tensors{encoded_images.size()}; + std::vector output_tensors{encoded_images.size()}; std::vector decoded_images{encoded_images.size()}; - for (std::vector::size_type i = 0; i < encoded_images.size(); - i++) { + for (size_t i = 0; i < encoded_images.size(); i++) { // extract bitstream meta data to figure out the number of channels, height, // width for every image status = nvjpegGetImageInfo( nvjpeg_handle, - (unsigned char*)encoded_images[i].data_ptr(), + encoded_images[i].const_data_ptr(), encoded_images[i].numel(), &channels[i], &subsampling, width, height); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to get image info: ", status); - TORCH_CHECK( + VISION_CHECK( subsampling != NVJPEG_CSS_UNKNOWN, "Unknown chroma subsampling"); // output channels may be different from the actual number of channels in @@ -368,14 +365,16 @@ CUDAJpegDecoder::prepare_buffers( } // reserve output buffer - auto output_tensor = torch::empty( + auto output_tensor = empty( {int64_t(output_channels), 
int64_t(height[0]), int64_t(width[0])}, - torch::dtype(torch::kU8).device(target_device)); + kByte, + target_device); output_tensors[i] = output_tensor; // fill nvjpegImage_t struct for (int c = 0; c < output_channels; c++) { - decoded_images[i].channel[c] = output_tensor[c].data_ptr(); + decoded_images[i].channel[c] = torch::stable::select(output_tensor, 0, c) + .mutable_data_ptr(); decoded_images[i].pitch[c] = width[0]; } for (int c = output_channels; c < NVJPEG_MAX_COMPONENT; c++) { @@ -386,8 +385,8 @@ CUDAJpegDecoder::prepare_buffers( return {decoded_images, output_tensors, channels}; } -std::vector CUDAJpegDecoder::decode_images( - const std::vector& encoded_images, +std::vector CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format) { /* This function decodes a batch of jpeg bitstreams. @@ -402,14 +401,14 @@ std::vector CUDAJpegDecoder::decode_images( for reference. Args: - - encoded_images (std::vector): a vector of tensors + - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y or NVJPEG_OUTPUT_UNCHANGED - - device (torch::Device): The desired CUDA device for the returned Tensors + - device (Device): The desired CUDA device for the returned Tensors Returns: - - output_tensors (std::vector): a vector of Tensors + - output_tensors (std::vector): a vector of Tensors containing the decoded images */ @@ -420,7 +419,7 @@ std::vector CUDAJpegDecoder::decode_images( cudaError_t cudaStatus; cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); @@ -438,13 +437,12 @@ std::vector CUDAJpegDecoder::decode_images( std::vector sw_output_buffer; if (hw_decode_available) { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { + for (size_t i = 0; i < encoded_images.size(); ++i) { // 
extract bitstream meta data to figure out whether a bit-stream can be // decoded nvjpegJpegStreamParseHeader( nvjpeg_handle, - encoded_images[i].data_ptr(), + encoded_images[i].const_data_ptr(), encoded_images[i].numel(), jpeg_streams[0]); int isSupported = -1; @@ -452,19 +450,18 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_handle, jpeg_streams[0], &isSupported); if (isSupported == 0) { - hw_input_buffer.push_back(encoded_images[i].data_ptr()); + hw_input_buffer.push_back(encoded_images[i].const_data_ptr()); hw_input_buffer_size.push_back(encoded_images[i].numel()); hw_output_buffer.push_back(decoded_imgs_buf[i]); } else { - sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer.push_back(encoded_images[i].const_data_ptr()); sw_input_buffer_size.push_back(encoded_images[i].numel()); sw_output_buffer.push_back(decoded_imgs_buf[i]); } } } else { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { - sw_input_buffer.push_back(encoded_images[i].data_ptr()); + for (size_t i = 0; i < encoded_images.size(); ++i) { + sw_input_buffer.push_back(encoded_images[i].const_data_ptr()); sw_input_buffer_size.push_back(encoded_images[i].numel()); sw_output_buffer.push_back(decoded_imgs_buf[i]); } @@ -479,7 +476,7 @@ std::vector CUDAJpegDecoder::decode_images( 1, output_format == NVJPEG_OUTPUT_UNCHANGED ? 
NVJPEG_OUTPUT_RGB : output_format); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to initialize batch decoding: ", status); @@ -491,14 +488,14 @@ std::vector CUDAJpegDecoder::decode_images( hw_input_buffer_size.data(), hw_output_buffer.data(), stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode batch: ", status); } if (sw_input_buffer.size() > 0) { status = nvjpegStateAttachDeviceBuffer(nvjpeg_decoupled_state, device_buffer); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to attach device buffer: ", status); @@ -508,12 +505,11 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decode_params, output_format == NVJPEG_OUTPUT_UNCHANGED ? NVJPEG_OUTPUT_RGB : output_format); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to set output format: ", status); - for (std::vector::size_type i = 0; i < sw_input_buffer.size(); - ++i) { + for (size_t i = 0; i < sw_input_buffer.size(); ++i) { status = nvjpegJpegStreamParse( nvjpeg_handle, sw_input_buffer[i], @@ -521,14 +517,14 @@ std::vector CUDAJpegDecoder::decode_images( 0, 0, jpeg_streams[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to parse jpeg stream: ", status); status = nvjpegStateAttachPinnedBuffer( nvjpeg_decoupled_state, pinned_buffers[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to attach pinned buffer: ", status); @@ -539,13 +535,13 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, nvjpeg_decode_params, jpeg_streams[buffer_index]); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode jpeg stream: ", status); cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); @@ -556,7 +552,7 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, 
jpeg_streams[buffer_index], stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to transfer jpeg to device: ", status); @@ -570,7 +566,7 @@ std::vector CUDAJpegDecoder::decode_images( nvjpeg_decoupled_state, &sw_output_buffer[i], stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to decode jpeg stream: ", status); @@ -578,17 +574,17 @@ std::vector CUDAJpegDecoder::decode_images( } cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( + VISION_CHECK( cudaStatus == cudaSuccess, "Failed to synchronize CUDA stream: ", cudaStatus); // prune extraneous channels from single channel images if (output_format == NVJPEG_OUTPUT_UNCHANGED) { - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { + for (size_t i = 0; i < output_tensors.size(); ++i) { if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + output_tensors[i] = torch::stable::clone(torch::stable::unsqueeze( + torch::stable::select(output_tensors[i], 0, 0), 0)); } } } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..7797e7ad345 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include "../../../StableABICompat.h" #include "../common.h" #if NVJPEG_FOUND @@ -11,24 +11,24 @@ namespace vision { namespace image { class CUDAJpegDecoder { public: - CUDAJpegDecoder(const torch::Device& target_device); + CUDAJpegDecoder(const vision::stable::Device& target_device); ~CUDAJpegDecoder(); - std::vector decode_images( - const std::vector& encoded_images, + std::vector decode_images( + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format); - const torch::Device original_device; - const torch::Device target_device; + const vision::stable::Device original_device; + const 
vision::stable::Device target_device; const c10::cuda::CUDAStream stream; private: std::tuple< std::vector, - std::vector, + std::vector, std::vector> prepare_buffers( - const std::vector& encoded_images, + const std::vector& encoded_images, const nvjpegOutputFormat_t& output_format); nvjpegJpegState_t nvjpeg_state; nvjpegJpegState_t nvjpeg_decoupled_state; diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 8c3ad8f9a9d..9074e9464f9 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../../StableABICompat.h" #include "../common.h" #include "decode_jpegs_cuda.h" #include "encode_jpegs_cuda.h" @@ -14,45 +14,45 @@ Fast jpeg decoding with CUDA. A100+ GPUs have dedicated hardware support for jpeg decoding. Args: - - encoded_images (const std::vector&): a vector of tensors - containing the jpeg bitstreams to be decoded. Each tensor must have dtype - torch.uint8 and device cpu + - encoded_images (const std::vector&): a vector of +tensors containing the jpeg bitstreams to be decoded. Each tensor must have +dtype torch.uint8 and device cpu - mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and IMAGE_READ_MODE_RGB are supported - - device (torch::Device): The desired CUDA device to run the decoding on and -which will contain the output tensors + - device (vision::stable::Device): The desired CUDA device to run the +decoding on and which will contain the output tensors Returns: - - decoded_images (std::vector): a vector of torch::Tensors of -dtype torch.uint8 on the specified containing the decoded images + - decoded_images (std::vector): a vector of Tensors +of dtype torch.uint8 on the specified containing the decoded images Notes: - If a single image fails, the whole batch fails. 
- This function is thread-safe */ -C10_EXPORT std::vector decode_jpegs_cuda( - const std::vector& encoded_images, +C10_EXPORT std::vector decode_jpegs_cuda( + const std::vector& encoded_images, vision::image::ImageReadMode mode, - torch::Device device); + vision::stable::Device device); /* Fast jpeg encoding with CUDA. Args: - - decoded_images (const std::vector&): a vector of contiguous -CUDA tensors of dtype torch.uint8 to be encoded. + - decoded_images (const std::vector&): a vector of +contiguous CUDA tensors of dtype torch.uint8 to be encoded. - quality (int64_t): 0-100, 75 is the default Returns: - - encoded_images (std::vector): a vector of CUDA -torch::Tensors of dtype torch.uint8 containing the encoded images + - encoded_images (std::vector): a vector of CUDA +Tensors of dtype torch.uint8 containing the encoded images Notes: - If a single image fails, the whole batch fails. - This function is thread-safe */ -C10_EXPORT std::vector encode_jpegs_cuda( - const std::vector& decoded_images, +C10_EXPORT std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality); } // namespace image diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index 80accc1a241..a61b12be5bd 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -2,10 +2,13 @@ #if !NVJPEG_FOUND namespace vision { namespace image { -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, + +using namespace vision::stable; + +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality) { - TORCH_CHECK( + VISION_CHECK( false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); } } // namespace image @@ -17,29 +20,30 @@ std::vector encode_jpegs_cuda( #include #include #include -#include "c10/core/ScalarType.h" namespace vision { namespace image { +using namespace vision::stable; + 
// We use global variables to cache the encoder and decoder instances and // reuse them across calls to the corresponding pytorch functions std::mutex encoderMutex; std::unique_ptr cudaJpegEncoder; -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, const int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); + // Note: C10_LOG_API_USAGE_ONCE is not available in stable ABI // Some nvjpeg structures are not thread safe so we're keeping it single // threaded for now. In the future this may be an opportunity to unlock // further speedups std::lock_guard lock(encoderMutex); - TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); - torch::Device device = decoded_images[0].device(); - at::cuda::CUDAGuard device_guard(device); + VISION_CHECK(decoded_images.size() > 0, "Empty input tensor list"); + Device device = Device( + decoded_images[0].device().type(), decoded_images[0].device().index()); + at::cuda::CUDAGuard device_guard(at::Device(at::kCUDA, device.index())); // lazy init of the encoder class // the encoder object holds on to a lot of state and is expensive to create, @@ -63,35 +67,36 @@ std::vector encode_jpegs_cuda( std::atexit([]() { delete cudaJpegEncoder.release(); }); } - std::vector contig_images; + std::vector contig_images; contig_images.reserve(decoded_images.size()); for (const auto& image : decoded_images) { - TORCH_CHECK( - image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + VISION_CHECK( + image.scalar_type() == kByte, "Input tensor dtype should be uint8"); - TORCH_CHECK( - image.device() == device, + VISION_CHECK( + image.device().type() == device.type() && + image.device().index() == device.index(), "All input tensors must be on the same CUDA device when encoding with nvjpeg") - TORCH_CHECK( + VISION_CHECK( image.dim() == 3 && image.numel() > 0, "Input data should be a 
3-dimensional tensor"); - TORCH_CHECK( + VISION_CHECK( image.size(0) == 3, "The number of channels should be 3, got: ", image.size(0)); // nvjpeg requires images to be contiguous - if (image.is_contiguous()) { + if (is_contiguous(image)) { contig_images.push_back(image); } else { - contig_images.push_back(image.contiguous()); + contig_images.push_back(torch::stable::contiguous(image)); } } cudaJpegEncoder->set_quality(quality); - std::vector encoded_images; + std::vector encoded_images; for (const auto& image : contig_images) { auto encoded_image = cudaJpegEncoder->encode_jpeg(image); encoded_images.push_back(encoded_image); @@ -110,8 +115,8 @@ std::vector encode_jpegs_cuda( return encoded_images; } -CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, +CUDAJpegEncoder::CUDAJpegEncoder(const Device& target_device) + : original_device{kCUDA, c10::cuda::current_device()}, target_device{target_device}, stream{ target_device.has_index() @@ -123,19 +128,19 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) : at::cuda::getCurrentCUDAStream()} { nvjpegStatus_t status; status = nvjpegCreateSimple(&nvjpeg_handle); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg handle: ", status); status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg encoder state: ", status); status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to create nvjpeg encoder params: ", status); @@ -166,13 +171,13 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { // nvjpegStatus_t status; // status = nvjpegEncoderParamsDestroy(nv_enc_params); - // TORCH_CHECK( + // VISION_CHECK( // status == NVJPEG_STATUS_SUCCESS, // "Failed to destroy nvjpeg encoder params: ", // status); // status 
= nvjpegEncoderStateDestroy(nv_enc_state); - // TORCH_CHECK( + // VISION_CHECK( // status == NVJPEG_STATUS_SUCCESS, // "Failed to destroy nvjpeg encoder state: ", // status); @@ -180,17 +185,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { // cudaStreamSynchronize(stream); // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( + // VISION_CHECK( // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); } -torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { +Tensor CUDAJpegEncoder::encode_jpeg(const Tensor& src_image) { nvjpegStatus_t status; cudaError_t cudaStatus; // Ensure that the incoming src_image is safe to use cudaStatus = cudaStreamSynchronize(current_stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); int channels = src_image.size(0); int height = src_image.size(1); @@ -198,14 +203,15 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { status = nvjpegEncoderParamsSetSamplingFactors( nv_enc_params, NVJPEG_CSS_444, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to set nvjpeg encoder params sampling factors: ", status); nvjpegImage_t target_image; for (int c = 0; c < channels; c++) { - target_image.channel[c] = src_image[c].data_ptr(); + target_image.channel[c] = + torch::stable::select(src_image, 0, c).mutable_data_ptr(); // this is why we need contiguous tensors target_image.pitch[c] = width; } @@ -224,39 +230,34 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { height, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); // Retrieve length of the encoded image size_t length; status = nvjpegEncodeRetrieveBitstreamDevice( nvjpeg_handle, nv_enc_state, NULL, &length, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to retrieve encoded image stream 
state: ", status); // Synchronize the stream to ensure that the encoded image is ready cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); // Reserve buffer for the encoded image - torch::Tensor encoded_image = torch::empty( - {static_cast(length)}, - torch::TensorOptions() - .dtype(torch::kByte) - .layout(torch::kStrided) - .device(target_device) - .requires_grad(false)); + Tensor encoded_image = + empty({static_cast(length)}, kByte, target_device); cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + VISION_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); // Retrieve the encoded image status = nvjpegEncodeRetrieveBitstreamDevice( nvjpeg_handle, nv_enc_state, - encoded_image.data_ptr(), + encoded_image.mutable_data_ptr(), &length, stream); - TORCH_CHECK( + VISION_CHECK( status == NVJPEG_STATUS_SUCCESS, "Failed to retrieve encoded image: ", status); @@ -266,7 +267,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { void CUDAJpegEncoder::set_quality(const int64_t quality) { nvjpegStatus_t paramsQualityStatus = nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); - TORCH_CHECK( + VISION_CHECK( paramsQualityStatus == NVJPEG_STATUS_SUCCESS, "Failed to set nvjpeg encoder params quality: ", paramsQualityStatus); diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h index 6ee0ad91df4..9964374e03a 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include "../../../StableABICompat.h" #if NVJPEG_FOUND #include @@ -12,15 +12,15 @@ namespace image { class CUDAJpegEncoder { public: - CUDAJpegEncoder(const torch::Device& device); + CUDAJpegEncoder(const 
vision::stable::Device& device); ~CUDAJpegEncoder(); - torch::Tensor encode_jpeg(const torch::Tensor& src_image); + vision::stable::Tensor encode_jpeg(const vision::stable::Tensor& src_image); void set_quality(const int64_t quality); - const torch::Device original_device; - const torch::Device target_device; + const vision::stable::Device original_device; + const vision::stable::Device target_device; const c10::cuda::CUDAStream stream; const c10::cuda::CUDAStream current_stream; diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index b4a4ed54a67..0d0042ca953 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -1,29 +1,57 @@ #include "image.h" -#include +#include "../../StableABICompat.h" namespace vision { namespace image { -static auto registry = - torch::RegisterOperators() - .op("image::decode_gif", &decode_gif) - .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_png) - .op("image::encode_png", &encode_png) - .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_jpeg) - .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", - &decode_webp) - .op("image::encode_jpeg", &encode_jpeg) - .op("image::read_file", &read_file) - .op("image::write_file", &write_file) - .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_image) - .op("image::decode_jpegs_cuda", &decode_jpegs_cuda) - .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) - .op("image::_jpeg_version", &_jpeg_version) - .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); +// Register operators using stable ABI macros +STABLE_TORCH_LIBRARY(image, m) { + m.def("decode_gif(Tensor data) -> Tensor"); + m.def( + "decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + m.def("encode_png(Tensor data, int compression_level) -> 
Tensor"); + m.def( + "decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + m.def("decode_webp(Tensor encoded_data, int mode) -> Tensor"); + m.def("encode_jpeg(Tensor data, int quality) -> Tensor"); + m.def("read_file(str filename) -> Tensor"); + m.def("write_file(str filename, Tensor data) -> ()"); + m.def( + "decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); + // CUDA JPEG ops are disabled when building with stable ABI + // (TORCH_TARGET_VERSION) because the stable ABI doesn't expose raw CUDA + // streams needed by nvJPEG. m.def("decode_jpegs_cuda(Tensor[] encoded_jpegs, + // int mode, Device device) -> Tensor[]"); m.def("encode_jpegs_cuda(Tensor[] + // decoded_jpegs, int quality) -> Tensor[]"); + m.def("_jpeg_version() -> int"); + m.def("_is_compiled_against_turbo() -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) { + m.impl("decode_gif", TORCH_BOX(&decode_gif)); + m.impl("decode_png", TORCH_BOX(&decode_png)); + m.impl("encode_png", TORCH_BOX(&encode_png)); + m.impl("decode_jpeg", TORCH_BOX(&decode_jpeg)); + m.impl("decode_webp", TORCH_BOX(&decode_webp)); + m.impl("encode_jpeg", TORCH_BOX(&encode_jpeg)); + m.impl("decode_image", TORCH_BOX(&decode_image)); +} + +// Ops without tensor inputs need BackendSelect dispatch +STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { + m.impl("read_file", TORCH_BOX(&read_file)); + m.impl("write_file", TORCH_BOX(&write_file)); + m.impl("_jpeg_version", TORCH_BOX(&_jpeg_version)); + m.impl("_is_compiled_against_turbo", TORCH_BOX(&_is_compiled_against_turbo)); +} + +// CUDA JPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) +// because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. 
+// STABLE_TORCH_LIBRARY_IMPL(image, CUDA, m) { +// m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); +// m.impl("encode_jpegs_cuda", TORCH_BOX(&encode_jpegs_cuda)); +// } } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 3f47fdec65c..35b66d951c6 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -8,4 +8,6 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -#include "cuda/encode_decode_jpegs_cuda.h" +// CUDA JPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) +// because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. +// #include "cuda/encode_decode_jpegs_cuda.h" From 195ca0ffbfb75aa0aa140fd8d9ef6a90c6b84941 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 26 Jan 2026 14:13:19 +0000 Subject: [PATCH 2/9] Fixed some CUDA stuff --- setup.py | 42 +++---- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 113 ++++++------------ .../csrc/io/image/cuda/decode_jpegs_cuda.h | 4 +- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 83 ++++++------- .../csrc/io/image/cuda/encode_jpegs_cuda.h | 7 +- torchvision/csrc/io/image/image.cpp | 22 ++-- torchvision/csrc/io/image/image.h | 4 +- 7 files changed, 106 insertions(+), 169 deletions(-) diff --git a/setup.py b/setup.py index 223d2419b7a..257cacc97a2 100644 --- a/setup.py +++ b/setup.py @@ -299,15 +299,13 @@ def make_image_extension(): image_dir = CSRS_DIR / "io/image" sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c")) - # Note: CUDA sources are excluded when building with stable ABI (TORCH_TARGET_VERSION) - # because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. - # When stable ABI CUDA support is added to PyTorch, this can be re-enabled. 
- # if IS_ROCM: - # sources += list(image_dir.glob("hip/*.cpp")) - # # we need to exclude this in favor of the hipified source - # sources.remove(image_dir / "image.cpp") - # else: - # sources += list(image_dir.glob("cuda/*.cpp")) + if BUILD_CUDA_SOURCES: + if IS_ROCM: + sources += list(image_dir.glob("hip/*.cpp")) + # we need to exclude this in favor of the hipified source + sources.remove(image_dir / "image.cpp") + else: + sources += list(image_dir.glob("cuda/*.cpp")) Extension = CppExtension @@ -355,20 +353,18 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - # NVJPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) - # because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. - # if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): - # nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - # - # if nvjpeg_found: - # print("Building torchvision with NVJPEG image support") - # libraries.append("nvjpeg") - # define_macros += [("NVJPEG_FOUND", 1)] - # Extension = CUDAExtension - # else: - # warnings.warn("Building torchvision without NVJPEG support") - # elif USE_NVJPEG: - # warnings.warn("Building torchvision without NVJPEG support") + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + + if nvjpeg_found: + print("Building torchvision with NVJPEG image support") + libraries.append("nvjpeg") + define_macros += [("NVJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without NVJPEG support") + elif USE_NVJPEG: + warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 5ef3a5ac996..613231ff456 100644 --- 
a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -16,9 +16,6 @@ std::vector decode_jpegs_cuda( } // namespace vision #else -#include -#include -#include #include #include #include @@ -87,13 +84,20 @@ std::vector decode_jpegs_cuda( // Note: TORCH_WARN_ONCE is not available in stable ABI, so we skip the // warning about nvjpeg memory leaks in CUDA versions < 11.6 - at::cuda::CUDAGuard device_guard(at::Device(at::kCUDA, device.index())); + // Set the target CUDA device + int prev_device; + cudaGetDevice(&prev_device); + int target_device_idx = device.has_index() ? device.index() : prev_device; + cudaSetDevice(target_device_idx); - if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { + // Create a new device with the resolved index for consistency + Device resolved_device(kCUDA, static_cast(target_device_idx)); + + if (cudaJpegDecoder == nullptr || resolved_device != cudaJpegDecoder->target_device) { if (cudaJpegDecoder != nullptr) { - cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); + cudaJpegDecoder.reset(new CUDAJpegDecoder(resolved_device)); } else { - cudaJpegDecoder = std::make_unique(device); + cudaJpegDecoder = std::make_unique(resolved_device); std::atexit([]() { cudaJpegDecoder.reset(); }); } } @@ -120,16 +124,21 @@ std::vector decode_jpegs_cuda( } try { - at::cuda::CUDAEvent event; auto result = cudaJpegDecoder->decode_images(contig_images, output_format); - auto current_stream{ - device.has_index() ? 
at::cuda::getCurrentCUDAStream( - cudaJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(cudaJpegDecoder->stream); - event.block(current_stream); + + // Synchronize the decoder stream with the current stream using events + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, cudaJpegDecoder->stream); + // Use the default stream for synchronization + cudaStreamWaitEvent(nullptr, event, 0); + cudaEventDestroy(event); + + // Restore original device + cudaSetDevice(prev_device); return result; } catch (const std::exception& e) { + cudaSetDevice(prev_device); if (typeid(e) != typeid(std::runtime_error)) { VISION_CHECK(false, "Error while decoding JPEG images: ", e.what()); } else { @@ -139,12 +148,16 @@ std::vector decode_jpegs_cuda( } CUDAJpegDecoder::CUDAJpegDecoder(const Device& target_device) - : original_device{kCUDA, c10::cuda::current_device()}, + : original_device{kCUDA, []() { + int dev; + cudaGetDevice(&dev); + return static_cast(dev); + }()}, target_device{target_device}, - stream{ - target_device.has_index() - ? at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)} { + stream{nullptr} { + // Create a CUDA stream for this decoder + cudaStreamCreate(&stream); + nvjpegStatus_t status; hw_decode_available = true; @@ -241,65 +254,9 @@ CUDAJpegDecoder::~CUDAJpegDecoder() { Please send a PR if you have a solution for this problem. 
*/ - // nvjpegStatus_t status; - - // status = nvjpegDecodeParamsDestroy(nvjpeg_decode_params); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decode params: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[0]: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[1]: ", - // status); - - // status = nvjpegBufferDeviceDestroy(device_buffer); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy device buffer: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_decoupled_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoupled state: ", - // status); - - // status = nvjpegDecoderDestroy(nvjpeg_decoder); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoder: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg state: ", - // status); - - // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); + // if (stream != nullptr) { + // cudaStreamDestroy(stream); + // } } std::tuple, std::vector, std::vector> diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 7797e7ad345..3274e119279 100644 --- 
a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -4,7 +4,7 @@ #include "../common.h" #if NVJPEG_FOUND -#include +#include #include namespace vision { @@ -20,7 +20,7 @@ class CUDAJpegDecoder { const vision::stable::Device original_device; const vision::stable::Device target_device; - const c10::cuda::CUDAStream stream; + cudaStream_t stream; private: std::tuple< diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index a61b12be5bd..aa582701b15 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -15,10 +15,10 @@ std::vector encode_jpegs_cuda( } // namespace vision #else -#include -#include +#include #include #include +#include #include namespace vision { @@ -43,18 +43,26 @@ std::vector encode_jpegs_cuda( VISION_CHECK(decoded_images.size() > 0, "Empty input tensor list"); Device device = Device( decoded_images[0].device().type(), decoded_images[0].device().index()); - at::cuda::CUDAGuard device_guard(at::Device(at::kCUDA, device.index())); + + // Set the target CUDA device + int prev_device; + cudaGetDevice(&prev_device); + int target_device_idx = device.has_index() ? device.index() : prev_device; + cudaSetDevice(target_device_idx); + + // Create a device with the resolved index for consistency + Device resolved_device(kCUDA, static_cast(target_device_idx)); // lazy init of the encoder class // the encoder object holds on to a lot of state and is expensive to create, // so we reuse it across calls. 
NB: the cached structures are device specific // and cannot be reused across devices - if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder == nullptr || resolved_device != cudaJpegEncoder->target_device) { if (cudaJpegEncoder != nullptr) { delete cudaJpegEncoder.release(); } - cudaJpegEncoder = std::make_unique(device); + cudaJpegEncoder = std::make_unique(resolved_device); // Unfortunately, we cannot rely on the smart pointer releasing the encoder // object correctly upon program exit. This is because, when cudaJpegEncoder @@ -101,8 +109,6 @@ std::vector encode_jpegs_cuda( auto encoded_image = cudaJpegEncoder->encode_jpeg(image); encoded_images.push_back(encoded_image); } - at::cuda::CUDAEvent event; - event.record(cudaJpegEncoder->stream); // We use a dedicated stream to do the encoding and even though the results // may be ready on that stream we cannot assume that they are also available @@ -111,21 +117,31 @@ std::vector encode_jpegs_cuda( // do not want to block the host at this particular point // (which is what cudaStreamSynchronize would do.) Events allow us to // synchronize the streams without blocking the host. - event.block(cudaJpegEncoder->current_stream); + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, cudaJpegEncoder->stream); + cudaStreamWaitEvent(cudaJpegEncoder->current_stream, event, 0); + cudaEventDestroy(event); + + // Restore original device + cudaSetDevice(prev_device); return encoded_images; } CUDAJpegEncoder::CUDAJpegEncoder(const Device& target_device) - : original_device{kCUDA, c10::cuda::current_device()}, + : original_device{kCUDA, []() { + int dev; + cudaGetDevice(&dev); + return static_cast(dev); + }()}, target_device{target_device}, - stream{ - target_device.has_index() - ? at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)}, - current_stream{ - original_device.has_index() - ? 
at::cuda::getCurrentCUDAStream(original_device.index()) - : at::cuda::getCurrentCUDAStream()} { + stream{nullptr}, + current_stream{nullptr} { + // Create CUDA streams + cudaStreamCreate(&stream); + // Get the default stream (nullptr represents the default stream) + current_stream = nullptr; + nvjpegStatus_t status; status = nvjpegCreateSimple(&nvjpeg_handle); VISION_CHECK( @@ -157,36 +173,9 @@ CUDAJpegEncoder::~CUDAJpegEncoder() { Please send a PR if you have a solution for this problem. */ - // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is - // still - // // initialized. If it is not, we can skip the rest of this function as it - // is - // // unsafe to execute. - // int deviceCount = 0; - // cudaError_t error = cudaGetDeviceCount(&deviceCount); - // if (error != cudaSuccess) - // return; // CUDA runtime has already shut down. There's nothing we can do - // // now. - - // nvjpegStatus_t status; - - // status = nvjpegEncoderParamsDestroy(nv_enc_params); - // VISION_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder params: ", - // status); - - // status = nvjpegEncoderStateDestroy(nv_enc_state); - // VISION_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder state: ", - // status); - - // cudaStreamSynchronize(stream); - - // status = nvjpegDestroy(nvjpeg_handle); - // VISION_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); + // if (stream != nullptr) { + // cudaStreamDestroy(stream); + // } } Tensor CUDAJpegEncoder::encode_jpeg(const Tensor& src_image) { diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h index 9964374e03a..020e1646340 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h @@ -3,8 +3,7 @@ #include "../../../StableABICompat.h" #if NVJPEG_FOUND -#include -#include +#include #include namespace vision { 
@@ -21,8 +20,8 @@ class CUDAJpegEncoder { const vision::stable::Device original_device; const vision::stable::Device target_device; - const c10::cuda::CUDAStream stream; - const c10::cuda::CUDAStream current_stream; + cudaStream_t stream; + cudaStream_t current_stream; protected: nvjpegEncoderState_t nv_enc_state; diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 0d0042ca953..98df109f11d 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -19,11 +19,9 @@ STABLE_TORCH_LIBRARY(image, m) { m.def("write_file(str filename, Tensor data) -> ()"); m.def( "decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor"); - // CUDA JPEG ops are disabled when building with stable ABI - // (TORCH_TARGET_VERSION) because the stable ABI doesn't expose raw CUDA - // streams needed by nvJPEG. m.def("decode_jpegs_cuda(Tensor[] encoded_jpegs, - // int mode, Device device) -> Tensor[]"); m.def("encode_jpegs_cuda(Tensor[] - // decoded_jpegs, int quality) -> Tensor[]"); + m.def( + "decode_jpegs_cuda(Tensor[] encoded_jpegs, int mode, Device device) -> Tensor[]"); + m.def("encode_jpegs_cuda(Tensor[] decoded_jpegs, int quality) -> Tensor[]"); m.def("_jpeg_version() -> int"); m.def("_is_compiled_against_turbo() -> bool"); } @@ -38,20 +36,20 @@ STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) { m.impl("decode_image", TORCH_BOX(&decode_image)); } -// Ops without tensor inputs need BackendSelect dispatch +// Ops without tensor inputs or with cross-device semantics need BackendSelect dispatch STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { m.impl("read_file", TORCH_BOX(&read_file)); m.impl("write_file", TORCH_BOX(&write_file)); m.impl("_jpeg_version", TORCH_BOX(&_jpeg_version)); m.impl("_is_compiled_against_turbo", TORCH_BOX(&_is_compiled_against_turbo)); + // decode_jpegs_cuda takes CPU tensors as input but outputs CUDA tensors + m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); } -// 
CUDA JPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) -// because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. -// STABLE_TORCH_LIBRARY_IMPL(image, CUDA, m) { -// m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); -// m.impl("encode_jpegs_cuda", TORCH_BOX(&encode_jpegs_cuda)); -// } +STABLE_TORCH_LIBRARY_IMPL(image, CUDA, m) { + // encode_jpegs_cuda takes CUDA tensors as input + m.impl("encode_jpegs_cuda", TORCH_BOX(&encode_jpegs_cuda)); +} } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 35b66d951c6..3f47fdec65c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -8,6 +8,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" -// CUDA JPEG is disabled when building with stable ABI (TORCH_TARGET_VERSION) -// because the stable ABI doesn't expose raw CUDA streams needed by nvJPEG. -// #include "cuda/encode_decode_jpegs_cuda.h" +#include "cuda/encode_decode_jpegs_cuda.h" From f881a29672d89e6abfeff5a126575a912a33a7ca Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 26 Jan 2026 14:55:46 +0000 Subject: [PATCH 3/9] Fixed stuff --- torchvision/csrc/StableABICompat.h | 16 +++++++--------- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 6 +----- torchvision/csrc/io/image/image.cpp | 7 +++++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/torchvision/csrc/StableABICompat.h b/torchvision/csrc/StableABICompat.h index 31eb5b1cd74..74c9886ae7d 100644 --- a/torchvision/csrc/StableABICompat.h +++ b/torchvision/csrc/StableABICompat.h @@ -448,15 +448,13 @@ inline Tensor clone(const Tensor& tensor) { // Create empty tensor with ChannelsLast memory format inline Tensor emptyCPUChannelsLast(const std::vector& sizes, ScalarType dtype) { - std::array stack{ - torch::stable::detail::from(IntArrayRef(sizes.data(), sizes.size())), - 
torch::stable::detail::from(dtype), - torch::stable::detail::from(kStrided), - torch::stable::detail::from(Device(kCPU)), - torch::stable::detail::from(MemoryFormat::ChannelsLast)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION)); - return torch::stable::detail::to(stack[0]); + return torch::stable::empty( + IntArrayRef(sizes.data(), sizes.size()), + dtype, + kStrided, + Device(kCPU), + std::nullopt, // pin_memory + MemoryFormat::ChannelsLast); } // =========================================================================== diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 613231ff456..6159a2ac9af 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -139,11 +139,7 @@ std::vector decode_jpegs_cuda( return result; } catch (const std::exception& e) { cudaSetDevice(prev_device); - if (typeid(e) != typeid(std::runtime_error)) { - VISION_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } else { - throw; - } + throw std::runtime_error(std::string("Error while decoding JPEG images: ") + e.what()); } } diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 98df109f11d..ebfecabbe2d 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -29,14 +29,14 @@ STABLE_TORCH_LIBRARY(image, m) { STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) { m.impl("decode_gif", TORCH_BOX(&decode_gif)); m.impl("decode_png", TORCH_BOX(&decode_png)); - m.impl("encode_png", TORCH_BOX(&encode_png)); m.impl("decode_jpeg", TORCH_BOX(&decode_jpeg)); m.impl("decode_webp", TORCH_BOX(&decode_webp)); - m.impl("encode_jpeg", TORCH_BOX(&encode_jpeg)); m.impl("decode_image", TORCH_BOX(&decode_image)); } // Ops without tensor inputs or with cross-device semantics need BackendSelect dispatch +// 
encode_jpeg/encode_png also use BackendSelect so they can give proper error messages +// when CUDA tensors are passed (instead of "no kernel for CUDA") STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { m.impl("read_file", TORCH_BOX(&read_file)); m.impl("write_file", TORCH_BOX(&write_file)); @@ -44,6 +44,9 @@ STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { m.impl("_is_compiled_against_turbo", TORCH_BOX(&_is_compiled_against_turbo)); // decode_jpegs_cuda takes CPU tensors as input but outputs CUDA tensors m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); + // encode functions need BackendSelect to provide proper error messages for CUDA inputs + m.impl("encode_png", TORCH_BOX(&encode_png)); + m.impl("encode_jpeg", TORCH_BOX(&encode_jpeg)); } STABLE_TORCH_LIBRARY_IMPL(image, CUDA, m) { From d62994ae20f1681581ea94c0a0647e11b629894a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 26 Jan 2026 14:58:39 +0000 Subject: [PATCH 4/9] Fix? --- torchvision/csrc/io/image/cpu/decode_png.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 4db0eb36b75..583a24fa679 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -200,13 +200,19 @@ Tensor decode_png( if (is_little_endian()) { png_set_swap(png_ptr); } - auto t_ptr = (uint8_t*)tensor.mutable_data_ptr(); + // Get raw pointer - for 16-bit images we get uint16_t* and cast to uint8_t* + auto t_ptr = is_16_bits + ? (uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); for (int pass = 0; pass < number_of_passes; pass++) { for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, t_ptr, nullptr); t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); } - t_ptr = (uint8_t*)tensor.mutable_data_ptr(); + // Reset pointer - for 16-bit images we get uint16_t* and cast to uint8_t* + t_ptr = is_16_bits + ? 
(uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); } int exif_orientation = -1; From 0ec9815cbad668d7f992935ac632e394b5962190 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 26 Jan 2026 19:31:48 +0000 Subject: [PATCH 5/9] Fix CPU? --- setup.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 257cacc97a2..f7ae5f28ab3 100644 --- a/setup.py +++ b/setup.py @@ -299,13 +299,13 @@ def make_image_extension(): image_dir = CSRS_DIR / "io/image" sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c")) - if BUILD_CUDA_SOURCES: - if IS_ROCM: - sources += list(image_dir.glob("hip/*.cpp")) - # we need to exclude this in favor of the hipified source - sources.remove(image_dir / "image.cpp") - else: - sources += list(image_dir.glob("cuda/*.cpp")) + # Always include CUDA sources - they have stubs when NVJPEG_FOUND is not defined + if IS_ROCM: + sources += list(image_dir.glob("hip/*.cpp")) + # we need to exclude this in favor of the hipified source + sources.remove(image_dir / "image.cpp") + else: + sources += list(image_dir.glob("cuda/*.cpp")) Extension = CppExtension From 047367a84f8598b8dc3973307fb3b3e94f605551 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 28 Jan 2026 17:30:14 +0000 Subject: [PATCH 6/9] Disable torchscript tests for write_png, write_file and write_jpeg. It doesn't work anymore. 
--- test/test_image.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index b11dd67ca12..06e0b760bc2 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -272,16 +272,14 @@ def test_encode_png_errors(): "img_path", [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], ) -@pytest.mark.parametrize("scripted", (True, False)) -def test_write_png(img_path, tmpdir, scripted): +def test_write_png(img_path, tmpdir): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) img_pil = img_pil.permute(2, 0, 1) filename, _ = os.path.splitext(os.path.basename(img_path)) torch_png = os.path.join(tmpdir, f"{filename}_torch.png") - write = torch.jit.script(write_png) if scripted else write_png - write(img_pil, torch_png, compression_level=6) + write_png(img_pil, torch_png, compression_level=6) saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = saved_image.permute(2, 0, 1) @@ -325,13 +323,11 @@ def test_read_file_non_ascii(tmpdir): assert_equal(data, expected) -@pytest.mark.parametrize("scripted", (True, False)) -def test_write_file(tmpdir, scripted): +def test_write_file(tmpdir): fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) - write = torch.jit.script(write_file) if scripted else write_file - write(fpath, content_tensor) + write_file(fpath, content_tensor) with open(fpath, "rb") as f: saved_content = f.read() @@ -808,8 +804,7 @@ def test_batch_encode_jpegs_cuda_errors(): "img_path", [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], ) -@pytest.mark.parametrize("scripted", (True, False)) -def test_write_jpeg(img_path, tmpdir, scripted): +def test_write_jpeg(img_path, tmpdir): tmpdir = Path(tmpdir) img = read_image(img_path) pil_img = 
F.to_pil_image(img) @@ -817,8 +812,7 @@ def test_write_jpeg(img_path, tmpdir, scripted): torch_jpeg = str(tmpdir / "torch.jpg") pil_jpeg = str(tmpdir / "pil.jpg") - write = torch.jit.script(write_jpeg) if scripted else write_jpeg - write(img, torch_jpeg, quality=75) + write_jpeg(img, torch_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75) with open(torch_jpeg, "rb") as f: From 28e3401e06a3c9479761a2e670c8c6a96f97f18b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 28 Jan 2026 17:34:54 +0000 Subject: [PATCH 7/9] Lint + remove graysacle warning for webp. We con't have TORCH_WARN_ONCE yet --- test/test_image.py | 1 + torchvision/csrc/StableABICompat.h | 140 ++++++++---------- torchvision/csrc/io/image/cpu/decode_png.cpp | 10 +- torchvision/csrc/io/image/cpu/decode_webp.cpp | 9 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 6 +- .../csrc/io/image/cuda/encode_jpegs_cuda.cpp | 3 +- torchvision/csrc/io/image/image.cpp | 10 +- 7 files changed, 88 insertions(+), 91 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 06e0b760bc2..e6f34790626 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -929,6 +929,7 @@ def test_decode_webp(decode_fun, scripted): img += 123 # make sure image buffer wasn't freed by underlying decoding lib +@pytest.mark.skip(reason="TODO_STABLE_ABI: need TORCH_WARN_ONCE") @pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) def test_decode_webp_grayscale(decode_fun, capfd): encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp"))) diff --git a/torchvision/csrc/StableABICompat.h b/torchvision/csrc/StableABICompat.h index 74c9886ae7d..26c5c194e6c 100644 --- a/torchvision/csrc/StableABICompat.h +++ b/torchvision/csrc/StableABICompat.h @@ -107,10 +107,7 @@ inline Tensor empty( Device device) { std::vector sizesVec(sizes); return torch::stable::empty( - IntArrayRef(sizesVec.data(), sizesVec.size()), - dtype, - kStrided, - device); + IntArrayRef(sizesVec.data(), sizesVec.size()), dtype, kStrided, 
device); } // Overload taking a vector @@ -119,16 +116,11 @@ inline Tensor empty( ScalarType dtype, Device device) { return torch::stable::empty( - IntArrayRef(sizes.data(), sizes.size()), - dtype, - kStrided, - device); + IntArrayRef(sizes.data(), sizes.size()), dtype, kStrided, device); } // Helper to create CPU tensors -inline Tensor emptyCPU( - std::initializer_list sizes, - ScalarType dtype) { +inline Tensor emptyCPU(std::initializer_list sizes, ScalarType dtype) { return empty(sizes, dtype, Device(kCPU)); } @@ -139,14 +131,11 @@ inline Tensor zeros( Device device) { std::vector sizesVec(sizes); auto tensor = torch::stable::empty( - IntArrayRef(sizesVec.data(), sizesVec.size()), - dtype, - kStrided, - device); + IntArrayRef(sizesVec.data(), sizesVec.size()), dtype, kStrided, device); // Use dispatcher to call aten::zero_ std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -165,13 +154,18 @@ inline Tensor to(const Tensor& tensor, const Device& device) { } // Stable version of tensor.narrow(dim, start, length) -inline Tensor narrow(Tensor tensor, int64_t dim, int64_t start, int64_t length) { +inline Tensor narrow( + Tensor tensor, + int64_t dim, + int64_t start, + int64_t length) { return torch::stable::narrow(tensor, dim, start, length); } // Note: contiguous() is provided by torch::stable::contiguous() directly // Do NOT define a vision::stable::contiguous wrapper as it conflicts with the -// default parameter in torch::stable::contiguous(tensor, memory_format = Contiguous) +// default parameter in torch::stable::contiguous(tensor, memory_format = +// Contiguous) // Stable version of tensor.select(dim, index) - from torch::stable::select inline Tensor select(const Tensor& tensor, int64_t dim, int64_t 
index) { @@ -195,7 +189,7 @@ inline std::pair sort( bool descending) { std::array stack{ torch::stable::detail::from(tensor), - torch::stable::detail::from(true), // stable sort + torch::stable::detail::from(true), // stable sort torch::stable::detail::from(dim), torch::stable::detail::from(descending)}; TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( @@ -218,42 +212,37 @@ inline Tensor permute( std::vector dimsVec(dims); std::array stack{ torch::stable::detail::from(tensor), - torch::stable::detail::from( - IntArrayRef(dimsVec.data(), dimsVec.size()))}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::permute", "", stack.data(), TORCH_ABI_VERSION)); + torch::stable::detail::from(IntArrayRef(dimsVec.data(), dimsVec.size()))}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::permute", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of at::cat() - concatenates tensors along a dimension inline Tensor cat(const std::vector& tensors, int64_t dim = 0) { std::array stack{ - torch::stable::detail::from(tensors), - torch::stable::detail::from(dim)}; + torch::stable::detail::from(tensors), torch::stable::detail::from(dim)}; TORCH_ERROR_CODE_CHECK( torch_call_dispatcher("aten::cat", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of at::clamp() -inline Tensor clamp( - const Tensor& tensor, - double min_val, - double max_val) { +inline Tensor clamp(const Tensor& tensor, double min_val, double max_val) { std::array stack{ torch::stable::detail::from(tensor), torch::stable::detail::from(min_val), torch::stable::detail::from(max_val)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::clamp", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::clamp", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of at::floor() inline Tensor floor(const Tensor& tensor) 
{ std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::floor", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::floor", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -270,8 +259,8 @@ inline Tensor reshape(const Tensor& tensor, const std::vector& shape) { std::array stack{ torch::stable::detail::from(tensor), torch::stable::detail::from(IntArrayRef(shape.data(), shape.size()))}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::reshape", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::reshape", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -291,8 +280,8 @@ inline Tensor flatten(const Tensor& tensor, int64_t start_dim) { torch::stable::detail::from(tensor), torch::stable::detail::from(start_dim), torch::stable::detail::from(static_cast(-1))}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -301,8 +290,8 @@ inline Tensor flatten(const Tensor& tensor, int64_t start_dim) { // Stable version of tensor.zero_() inline Tensor& zero_(Tensor& tensor) { std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); tensor = torch::stable::detail::to(stack[0]); return tensor; } @@ -313,10 +302,10 @@ inline Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2) { torch::stable::detail::from(self), torch::stable::detail::from(mat1), torch::stable::detail::from(mat2), - 
torch::stable::detail::from(1.0), // beta + torch::stable::detail::from(1.0), // beta torch::stable::detail::from(1.0)}; // alpha - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::addmm_", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::addmm_", "", stack.data(), TORCH_ABI_VERSION)); self = torch::stable::detail::to(stack[0]); return self; } @@ -327,13 +316,10 @@ inline Tensor zeros( ScalarType dtype, Device device) { auto tensor = torch::stable::empty( - IntArrayRef(sizes.data(), sizes.size()), - dtype, - kStrided, - device); + IntArrayRef(sizes.data(), sizes.size()), dtype, kStrided, device); std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::zero_", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -341,18 +327,17 @@ inline Tensor zeros( inline Tensor zeros_like(const Tensor& tensor) { // Use dispatcher to call aten::zeros_like std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::zeros_like", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::zeros_like", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of at::ones_like() * val inline Tensor ones_like_times(const Tensor& tensor, const Tensor& val) { std::array stack{ - torch::stable::detail::from(tensor), - torch::stable::detail::from(val)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::mul", "Tensor", stack.data(), TORCH_ABI_VERSION)); + torch::stable::detail::from(tensor), torch::stable::detail::from(val)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::mul", "Tensor", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -362,8 +347,8 @@ inline Tensor 
sum(const Tensor& tensor, const std::vector& dims) { torch::stable::detail::from(tensor), torch::stable::detail::from(IntArrayRef(dims.data(), dims.size())), torch::stable::detail::from(false)}; // keepdim - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::sum", "dim_IntList", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::sum", "dim_IntList", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -373,29 +358,31 @@ inline Tensor add(const Tensor& self, const Tensor& other) { torch::stable::detail::from(self), torch::stable::detail::from(other), torch::stable::detail::from(1.0)}; // alpha - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::add", "Tensor", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::add", "Tensor", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of tensor.index_select(dim, index) -inline Tensor index_select(const Tensor& tensor, int64_t dim, const Tensor& index) { +inline Tensor index_select( + const Tensor& tensor, + int64_t dim, + const Tensor& index) { std::array stack{ torch::stable::detail::from(tensor), torch::stable::detail::from(dim), torch::stable::detail::from(index)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::index_select", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::index_select", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of tensor.masked_select(mask) inline Tensor masked_select(const Tensor& tensor, const Tensor& mask) { std::array stack{ - torch::stable::detail::from(tensor), - torch::stable::detail::from(mask)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::masked_select", "", stack.data(), TORCH_ABI_VERSION)); + torch::stable::detail::from(tensor), torch::stable::detail::from(mask)}; + 
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::masked_select", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -421,8 +408,8 @@ inline Tensor from_file( torch::stable::detail::from(shared), torch::stable::detail::from(size), torch::stable::detail::from(dtype)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::from_file", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::from_file", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } @@ -431,29 +418,30 @@ inline Tensor from_file( // Stable version of tensor.unsqueeze(dim) inline Tensor unsqueeze(const Tensor& tensor, int64_t dim) { std::array stack{ - torch::stable::detail::from(tensor), - torch::stable::detail::from(dim)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION)); + torch::stable::detail::from(tensor), torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Stable version of tensor.clone() inline Tensor clone(const Tensor& tensor) { std::array stack{torch::stable::detail::from(tensor)}; - TORCH_ERROR_CODE_CHECK( - torch_call_dispatcher("aten::clone", "", stack.data(), TORCH_ABI_VERSION)); + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::clone", "", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); } // Create empty tensor with ChannelsLast memory format -inline Tensor emptyCPUChannelsLast(const std::vector& sizes, ScalarType dtype) { +inline Tensor emptyCPUChannelsLast( + const std::vector& sizes, + ScalarType dtype) { return torch::stable::empty( IntArrayRef(sizes.data(), sizes.size()), dtype, kStrided, Device(kCPU), - std::nullopt, // pin_memory + std::nullopt, // pin_memory MemoryFormat::ChannelsLast); } diff --git 
a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 583a24fa679..d7afab8a347 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -201,18 +201,16 @@ Tensor decode_png( png_set_swap(png_ptr); } // Get raw pointer - for 16-bit images we get uint16_t* and cast to uint8_t* - auto t_ptr = is_16_bits - ? (uint8_t*)tensor.mutable_data_ptr() - : tensor.mutable_data_ptr(); + auto t_ptr = is_16_bits ? (uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); for (int pass = 0; pass < number_of_passes; pass++) { for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, t_ptr, nullptr); t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); } // Reset pointer - for 16-bit images we get uint16_t* and cast to uint8_t* - t_ptr = is_16_bits - ? (uint8_t*)tensor.mutable_data_ptr() - : tensor.mutable_data_ptr(); + t_ptr = is_16_bits ? (uint8_t*)tensor.mutable_data_ptr() + : tensor.mutable_data_ptr(); } int exif_orientation = -1; diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp index 945fffeba70..17ebc9ab267 100644 --- a/torchvision/csrc/io/image/cpu/decode_webp.cpp +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -31,8 +31,13 @@ Tensor decode_webp(const Tensor& encoded_data, ImageReadMode mode) { VISION_CHECK( !features.has_animation, "Animated webp files are not supported."); - // Note: TORCH_WARN_ONCE is not available in stable ABI, so we just skip the - // warning + // TODO_STABLE_ABI: need TORCH_WARN_ONCE + // if (mode == IMAGE_READ_MODE_GRAY || mode == IMAGE_READ_MODE_GRAY_ALPHA) { + // TORCH_WARN_ONCE( + // "Webp does not support grayscale conversions. 
" + // "The returned tensor will be in the colorspace of the original + // image."); + // } auto return_rgb = should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 6159a2ac9af..79cedc5f4c5 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -93,7 +93,8 @@ std::vector decode_jpegs_cuda( // Create a new device with the resolved index for consistency Device resolved_device(kCUDA, static_cast(target_device_idx)); - if (cudaJpegDecoder == nullptr || resolved_device != cudaJpegDecoder->target_device) { + if (cudaJpegDecoder == nullptr || + resolved_device != cudaJpegDecoder->target_device) { if (cudaJpegDecoder != nullptr) { cudaJpegDecoder.reset(new CUDAJpegDecoder(resolved_device)); } else { @@ -139,7 +140,8 @@ std::vector decode_jpegs_cuda( return result; } catch (const std::exception& e) { cudaSetDevice(prev_device); - throw std::runtime_error(std::string("Error while decoding JPEG images: ") + e.what()); + throw std::runtime_error( + std::string("Error while decoding JPEG images: ") + e.what()); } } diff --git a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp index aa582701b15..17b6f46e17c 100644 --- a/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp @@ -57,7 +57,8 @@ std::vector encode_jpegs_cuda( // the encoder object holds on to a lot of state and is expensive to create, // so we reuse it across calls. 
NB: the cached structures are device specific // and cannot be reused across devices - if (cudaJpegEncoder == nullptr || resolved_device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder == nullptr || + resolved_device != cudaJpegEncoder->target_device) { if (cudaJpegEncoder != nullptr) { delete cudaJpegEncoder.release(); } diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index ebfecabbe2d..a64800aca2a 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -34,9 +34,10 @@ STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) { m.impl("decode_image", TORCH_BOX(&decode_image)); } -// Ops without tensor inputs or with cross-device semantics need BackendSelect dispatch -// encode_jpeg/encode_png also use BackendSelect so they can give proper error messages -// when CUDA tensors are passed (instead of "no kernel for CUDA") +// Ops without tensor inputs or with cross-device semantics need BackendSelect +// dispatch. encode_jpeg/encode_png also use BackendSelect so they can give +// proper error messages when CUDA tensors are passed (instead of "no kernel for +// CUDA") STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { m.impl("read_file", TORCH_BOX(&read_file)); m.impl("write_file", TORCH_BOX(&write_file)); @@ -44,7 +45,8 @@ STABLE_TORCH_LIBRARY_IMPL(image, BackendSelect, m) { m.impl("_is_compiled_against_turbo", TORCH_BOX(&_is_compiled_against_turbo)); // decode_jpegs_cuda takes CPU tensors as input but outputs CUDA tensors m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); - // encode functions need BackendSelect to provide proper error messages for CUDA inputs + // encode functions need BackendSelect to provide proper error messages for + // CUDA inputs m.impl("encode_png", TORCH_BOX(&encode_png)); m.impl("encode_jpeg", TORCH_BOX(&encode_jpeg)); } From 74699cc90606656007bcf83f6f0614b0a24213b6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 29 Jan 2026 12:52:22 +0000 Subject: 
[PATCH 8/9] Maybe fix rocm??? --- setup.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index e7d7a1b4531..9b060f0dfb7 100644 --- a/setup.py +++ b/setup.py @@ -287,13 +287,20 @@ def make_image_extension(): define_macros += [("TORCH_TARGET_VERSION", "0x020b000000000000")] image_dir = CSRS_DIR / "io/image" - sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c")) + # Exclude *_hip.cpp files - those are hipified versions that would cause multiple definition errors + sources = [s for s in image_dir.glob("*.cpp") if not s.name.endswith("_hip.cpp")] + sources += [s for s in image_dir.glob("cpu/*.cpp") if not s.name.endswith("_hip.cpp")] + sources += list(image_dir.glob("cpu/giflib/*.c")) # Always include CUDA sources - they have stubs when NVJPEG_FOUND is not defined if IS_ROCM: - sources += list(image_dir.glob("hip/*.cpp")) - # we need to exclude this in favor of the hipified source - sources.remove(image_dir / "image.cpp") + hip_sources = list(image_dir.glob("hip/*.cpp")) + if hip_sources: + sources += hip_sources + # Only remove image.cpp if we have a hipified replacement + if (image_dir / "image.cpp") in sources: + sources.remove(image_dir / "image.cpp") + # Note: if hip/ directory doesn't exist, we just use the regular sources else: sources += list(image_dir.glob("cuda/*.cpp")) From f0bfb6d13ef814bd627bd0ac914246724b261153 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 29 Jan 2026 13:20:58 +0000 Subject: [PATCH 9/9] rocm again? 
--- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9b060f0dfb7..22dfa1fdcca 100644 --- a/setup.py +++ b/setup.py @@ -300,7 +300,9 @@ def make_image_extension(): # Only remove image.cpp if we have a hipified replacement if (image_dir / "image.cpp") in sources: sources.remove(image_dir / "image.cpp") - # Note: if hip/ directory doesn't exist, we just use the regular sources + else: + # No hip/ directory - use cuda sources (they have stubs for non-NVJPEG builds) + sources += list(image_dir.glob("cuda/*.cpp")) else: sources += list(image_dir.glob("cuda/*.cpp"))