diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp
index a1674993bfc..3d0d8aed4ef 100644
--- a/torchvision/csrc/io/image/cpu/decode_image.cpp
+++ b/torchvision/csrc/io/image/cpu/decode_image.cpp
@@ -35,7 +35,9 @@ torch::Tensor decode_image(
   const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
   STD_TORCH_CHECK(data.numel() >= 4, err_msg);
   if (memcmp(png_signature, datap, 4) == 0) {
-    return decode_png(data, mode, apply_exif_orientation);
+    auto stable_data = vision::toStableTensor(data);
+    auto stable_result = decode_png(stable_data, mode, apply_exif_orientation);
+    return vision::fromStableTensor(stable_result);
   }
 
   const uint8_t gif_signature_1[6] = {
diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp
index 67c788455c4..714cb4fd69e 100644
--- a/torchvision/csrc/io/image/cpu/decode_png.cpp
+++ b/torchvision/csrc/io/image/cpu/decode_png.cpp
@@ -9,8 +9,8 @@ namespace image {
 using namespace exif_private;
 
 #if !PNG_FOUND
-torch::Tensor decode_png(
-    const torch::Tensor& data,
+torch::stable::Tensor decode_png(
+    const torch::stable::Tensor& data,
     ImageReadMode mode,
     bool apply_exif_orientation) {
   STD_TORCH_CHECK(
@@ -23,13 +23,11 @@ bool is_little_endian() {
   return *(uint8_t*)&x;
 }
 
-torch::Tensor decode_png(
-    const torch::Tensor& data,
+torch::stable::Tensor decode_png(
+    const torch::stable::Tensor& data,
     ImageReadMode mode,
     bool apply_exif_orientation) {
-  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
-
-  validate_encoded_data(data);
+  validate_encoded_data_stable(data);
 
   auto png_ptr =
       png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
@@ -41,7 +39,7 @@ torch::Tensor decode_png(
     STD_TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
   }
 
-  auto accessor = data.accessor<unsigned char, 1>();
+  auto accessor = constAccessor<uint8_t, 1>(data);
   auto datap = accessor.data();
   auto datap_len = accessor.size(0);
 
@@ -197,19 +195,21 @@ torch::Tensor decode_png(
   auto num_pixels_per_row = width * channels;
   auto is_16_bits = bit_depth == 16;
-  auto tensor = torch::empty(
-      {int64_t(height), int64_t(width), channels},
-      is_16_bits ? at::kUInt16 : torch::kU8);
+  int64_t tensor_sizes[] = {int64_t(height), int64_t(width), channels};
+  auto tensor = torch::stable::empty(
+      {tensor_sizes, 3},
+      is_16_bits ? torch::headeronly::ScalarType::UInt16
+                 : torch::headeronly::ScalarType::Byte);
 
   if (is_little_endian()) {
     png_set_swap(png_ptr);
   }
-  auto t_ptr = (uint8_t*)tensor.data_ptr();
+  auto t_ptr = static_cast<uint8_t*>(tensor.mutable_data_ptr());
   for (int pass = 0; pass < number_of_passes; pass++) {
     for (png_uint_32 i = 0; i < height; ++i) {
       png_read_row(png_ptr, t_ptr, nullptr);
       t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
     }
-    t_ptr = (uint8_t*)tensor.data_ptr();
+    t_ptr = static_cast<uint8_t*>(tensor.mutable_data_ptr());
   }
 
   int exif_orientation = -1;
@@ -219,9 +219,9 @@ torch::Tensor decode_png(
 
   png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
 
-  auto output = tensor.permute({2, 0, 1});
+  auto output = stablePermute(tensor, {2, 0, 1});
   if (apply_exif_orientation) {
-    return exif_orientation_transform(output, exif_orientation);
+    return exif_orientation_transform_stable(output, exif_orientation);
   }
   return output;
 }
diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h
index faaffa7ae49..45443a0582b 100644
--- a/torchvision/csrc/io/image/cpu/decode_png.h
+++ b/torchvision/csrc/io/image/cpu/decode_png.h
@@ -1,13 +1,13 @@
 #pragma once
 
-#include <torch/types.h>
+#include "../../../stable_abi_compat.h"
 
 #include "../common.h"
 
 namespace vision {
 namespace image {
 
-C10_EXPORT torch::Tensor decode_png(
-    const torch::Tensor& data,
+torch::stable::Tensor decode_png(
+    const torch::stable::Tensor& data,
     ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
     bool apply_exif_orientation = false);
diff --git a/torchvision/csrc/io/image/cpu/encode_png.cpp b/torchvision/csrc/io/image/cpu/encode_png.cpp
index d015f44cb39..a276278b243 100644
--- a/torchvision/csrc/io/image/cpu/encode_png.cpp
+++ b/torchvision/csrc/io/image/cpu/encode_png.cpp
@@ -1,4 +1,6 @@
-#include "encode_jpeg.h"
+#include "encode_png.h"
+
+#include <cstring>
 
 #include <png.h>
 
@@ -9,7 +11,9 @@ namespace image {
 
 #if !PNG_FOUND
 
-torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
+torch::stable::Tensor encode_png(
+    const torch::stable::Tensor& data,
+    int64_t compression_level) {
   STD_TORCH_CHECK(
       false, "encode_png: torchvision not compiled with libpng support");
 }
@@ -66,8 +70,9 @@ void torch_png_write_data(
 
 } // namespace
 
-torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
-  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");
+torch::stable::Tensor encode_png(
+    const torch::stable::Tensor& data,
+    int64_t compression_level) {
   // Define compression structures and error handling
   png_structp png_write;
   png_infop info_ptr;
@@ -104,12 +109,12 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
       "Compression level should be between 0 and 9");
 
   // Check that the input tensor is on CPU
-  STD_TORCH_CHECK(
-      data.device() == torch::kCPU, "Input tensor should be on CPU");
+  STD_TORCH_CHECK(data.is_cpu(), "Input tensor should be on CPU");
 
   // Check that the input tensor dtype is uint8
   STD_TORCH_CHECK(
-      data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
+      data.scalar_type() == torch::headeronly::ScalarType::Byte,
+      "Input tensor dtype should be uint8");
 
   // Check that the input tensor is 3-dimensional
   STD_TORCH_CHECK(
@@ -119,7 +124,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
   int channels = data.size(0);
   int height = data.size(1);
   int width = data.size(2);
-  auto input = data.permute({1, 2, 0}).contiguous();
+  auto input = torch::stable::contiguous(stablePermute(data, {1, 2, 0}));
 
   STD_TORCH_CHECK(
       channels == 1 || channels == 3,
@@ -155,7 +160,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
   png_write_info(png_write, info_ptr);
 
   auto stride = width * channels;
-  auto ptr = input.data_ptr<uint8_t>();
+  auto ptr = input.const_data_ptr<uint8_t>();
 
   // Encode PNG file
   for (int y = 0; y < height; ++y) {
@@ -169,12 +174,13 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
   // Destroy structures
   png_destroy_write_struct(&png_write, &info_ptr);
 
-  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
-  auto outTensor = torch::empty({(long)buf_info.size}, options);
+  int64_t out_size = static_cast<int64_t>(buf_info.size);
+  auto outTensor =
+      torch::stable::empty({&out_size, 1}, torch::headeronly::ScalarType::Byte);
 
   // Copy memory from png buffer, since torch cannot get ownership of it via
   // `from_blob`
-  auto outPtr = outTensor.data_ptr<uint8_t>();
+  auto outPtr = static_cast<uint8_t*>(outTensor.mutable_data_ptr());
   std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
   free(buf_info.buffer);
diff --git a/torchvision/csrc/io/image/cpu/encode_png.h b/torchvision/csrc/io/image/cpu/encode_png.h
index 86a67c8706e..68955fead16 100644
--- a/torchvision/csrc/io/image/cpu/encode_png.h
+++ b/torchvision/csrc/io/image/cpu/encode_png.h
@@ -1,12 +1,12 @@
 #pragma once
 
-#include <torch/types.h>
+#include "../../../stable_abi_compat.h"
 
 namespace vision {
 namespace image {
 
-C10_EXPORT torch::Tensor encode_png(
-    const torch::Tensor& data,
+torch::stable::Tensor encode_png(
+    const torch::stable::Tensor& data,
     int64_t compression_level);
 
 } // namespace image
diff --git a/torchvision/csrc/io/image/cpu/exif.h b/torchvision/csrc/io/image/cpu/exif.h
index e55a800b220..2730ef89cac 100644
--- a/torchvision/csrc/io/image/cpu/exif.h
+++ b/torchvision/csrc/io/image/cpu/exif.h
@@ -60,6 +60,7 @@ direct,
 
 #include <cstdint>
 #include <torch/types.h>
+#include "../../../stable_abi_compat.h"
 
 namespace vision {
 namespace image {
@@ -253,6 +254,33 @@ inline torch::Tensor exif_orientation_transform(
   return image;
 }
 
+// Stable ABI version of exif_orientation_transform
+inline torch::stable::Tensor exif_orientation_transform_stable(
+    const torch::stable::Tensor& image,
+    int orientation) {
+  if (orientation == IMAGE_ORIENTATION_TL) {
+    return image;
+  } else if (orientation == IMAGE_ORIENTATION_TR) {
+    return vision::stableFlip(image, {-1});
+  } else if (orientation == IMAGE_ORIENTATION_BR) {
+    // needs 180 rotation equivalent to
+    // flip both horizontally and vertically
+    return vision::stableFlip(image, {-2, -1});
+  } else if (orientation == IMAGE_ORIENTATION_BL) {
+    return vision::stableFlip(image, {-2});
+  } else if (orientation == IMAGE_ORIENTATION_LT) {
+    return torch::stable::transpose(image, -1, -2);
+  } else if (orientation == IMAGE_ORIENTATION_RT) {
+    return vision::stableFlip(torch::stable::transpose(image, -1, -2), {-1});
+  } else if (orientation == IMAGE_ORIENTATION_RB) {
+    return vision::stableFlip(
+        torch::stable::transpose(image, -1, -2), {-2, -1});
+  } else if (orientation == IMAGE_ORIENTATION_LB) {
+    return vision::stableFlip(torch::stable::transpose(image, -1, -2), {-2});
+  }
+  return image;
+}
+
 } // namespace exif_private
 } // namespace image
 } // namespace vision
diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp
index b4a4ed54a67..45e77b36db0 100644
--- a/torchvision/csrc/io/image/image.cpp
+++ b/torchvision/csrc/io/image/image.cpp
@@ -1,16 +1,15 @@
 #include "image.h"
 
 #include <torch/library.h>
+#include <torch/csrc/stable/library.h>
 
 namespace vision {
 namespace image {
 
+// Legacy registration for non-PNG ops
 static auto registry =
     torch::RegisterOperators()
         .op("image::decode_gif", &decode_gif)
-        .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
-            &decode_png)
-        .op("image::encode_png", &encode_png)
         .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
             &decode_jpeg)
         .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
@@ -25,5 +24,17 @@ static auto registry =
         .op("image::_jpeg_version", &_jpeg_version)
         .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo);
 
+// Stable ABI registration for PNG ops
+STABLE_TORCH_LIBRARY_FRAGMENT(image, m) {
+  m.def(
+      "decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor");
+  m.def("encode_png(Tensor data, int compression_level) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) {
+  m.impl("decode_png", TORCH_BOX(&decode_png));
+  m.impl("encode_png", TORCH_BOX(&encode_png));
+}
+
 } // namespace image
 } // namespace vision
diff --git a/torchvision/csrc/stable_abi_compat.h b/torchvision/csrc/stable_abi_compat.h
new file mode 100644
index 00000000000..e640b7e4414
--- /dev/null
+++ b/torchvision/csrc/stable_abi_compat.h
@@ -0,0 +1,89 @@
+#pragma once
+
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/stableivalue_conversions.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/core/TensorAccessor.h>
+#include <torch/headeronly/macros/Macros.h>
+#include <torch/headeronly/util/Exception.h>
+#include <torch/headeronly/version.h>
+
+#include <array>
+#include <utility>
+#include <vector>
+
+// Conversion helpers between at::Tensor and torch::stable::Tensor.
+// These are used at migration boundaries where some code is on the old API
+// and some is on the stable ABI.
+#include <ATen/core/Tensor.h>
+
+namespace vision {
+
+inline torch::stable::Tensor toStableTensor(at::Tensor t) {
+  return torch::stable::Tensor(
+      reinterpret_cast<AtenTensorHandle>(new at::Tensor(std::move(t))));
+}
+
+inline at::Tensor fromStableTensor(const torch::stable::Tensor& t) {
+  return *reinterpret_cast<at::Tensor*>(t.get());
+}
+
+// Dispatcher-based helpers for ops not yet in the stable ABI.
+inline torch::stable::Tensor stablePermute(
+    const torch::stable::Tensor& self,
+    std::vector<int64_t> dims) {
+  const auto num_args = 2;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self), torch::stable::detail::from(dims)};
+  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+      "aten::permute", "", stack.data(), TORCH_ABI_VERSION));
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
+inline torch::stable::Tensor stableFlip(
+    const torch::stable::Tensor& self,
+    std::vector<int64_t> dims) {
+  const auto num_args = 2;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self), torch::stable::detail::from(dims)};
+  TORCH_ERROR_CODE_CHECK(
+      torch_call_dispatcher("aten::flip", "", stack.data(), TORCH_ABI_VERSION));
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
+// Accessor helpers for torch::stable::Tensor, modeled after torchcodec's
+// StableABICompat.h. These construct a HeaderOnlyTensorAccessor from the
+// stable tensor's raw pointer, sizes, and strides.
+template <typename T, size_t N>
+torch::headeronly::HeaderOnlyTensorAccessor<T, N> mutableAccessor(
+    torch::stable::Tensor& tensor) {
+  return torch::headeronly::HeaderOnlyTensorAccessor<T, N>(
+      tensor.mutable_data_ptr<T>(),
+      tensor.sizes().data(),
+      tensor.strides().data());
+}
+
+template <typename T, size_t N>
+torch::headeronly::HeaderOnlyTensorAccessor<const T, N> constAccessor(
+    const torch::stable::Tensor& tensor) {
+  return torch::headeronly::HeaderOnlyTensorAccessor<const T, N>(
+      tensor.const_data_ptr<T>(),
+      tensor.sizes().data(),
+      tensor.strides().data());
+}
+
+// Stable ABI version of validate_encoded_data.
+inline void validate_encoded_data_stable(
+    const torch::stable::Tensor& encoded_data) {
+  STD_TORCH_CHECK(
+      encoded_data.is_contiguous(), "Input tensor must be contiguous.");
+  STD_TORCH_CHECK(
+      encoded_data.scalar_type() == torch::headeronly::ScalarType::Byte,
+      "Input tensor must have uint8 data type.");
+  STD_TORCH_CHECK(
+      encoded_data.dim() == 1 && encoded_data.numel() > 0,
+      "Input tensor must be 1-dimensional and non-empty.");
+}
+
+} // namespace vision