Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ torch::Tensor decode_image(
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
STD_TORCH_CHECK(data.numel() >= 4, err_msg);
if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(data, mode, apply_exif_orientation);
auto stable_data = vision::toStableTensor(data);
auto stable_result = decode_png(stable_data, mode, apply_exif_orientation);
return vision::fromStableTensor(stable_result);
}

const uint8_t gif_signature_1[6] = {
Expand Down
30 changes: 15 additions & 15 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace image {
using namespace exif_private;

#if !PNG_FOUND
torch::Tensor decode_png(
const torch::Tensor& data,
torch::stable::Tensor decode_png(
const torch::stable::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
STD_TORCH_CHECK(
Expand All @@ -23,13 +23,11 @@ bool is_little_endian() {
return *(uint8_t*)&x;
}

torch::Tensor decode_png(
const torch::Tensor& data,
torch::stable::Tensor decode_png(
const torch::stable::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");

validate_encoded_data(data);
validate_encoded_data_stable(data);

auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
Expand All @@ -41,7 +39,7 @@ torch::Tensor decode_png(
STD_TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
}

auto accessor = data.accessor<unsigned char, 1>();
auto accessor = constAccessor<unsigned char, 1>(data);
auto datap = accessor.data();
auto datap_len = accessor.size(0);

Expand Down Expand Up @@ -197,19 +195,21 @@ torch::Tensor decode_png(

auto num_pixels_per_row = width * channels;
auto is_16_bits = bit_depth == 16;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
is_16_bits ? at::kUInt16 : torch::kU8);
int64_t tensor_sizes[] = {int64_t(height), int64_t(width), channels};
auto tensor = torch::stable::empty(
{tensor_sizes, 3},
is_16_bits ? torch::headeronly::ScalarType::UInt16
: torch::headeronly::ScalarType::Byte);
if (is_little_endian()) {
png_set_swap(png_ptr);
}
auto t_ptr = (uint8_t*)tensor.data_ptr();
auto t_ptr = static_cast<uint8_t*>(tensor.mutable_data_ptr());
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
}
t_ptr = (uint8_t*)tensor.data_ptr();
t_ptr = static_cast<uint8_t*>(tensor.mutable_data_ptr());
}

int exif_orientation = -1;
Expand All @@ -219,9 +219,9 @@ torch::Tensor decode_png(

png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);

auto output = tensor.permute({2, 0, 1});
auto output = stablePermute(tensor, {2, 0, 1});
if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
return exif_orientation_transform_stable(output, exif_orientation);
}
return output;
}
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/io/image/cpu/decode_png.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#pragma once

#include <torch/types.h>
#include "../../../stable_abi_compat.h"
#include "../common.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
torch::stable::Tensor decode_png(
const torch::stable::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

Expand Down
30 changes: 18 additions & 12 deletions torchvision/csrc/io/image/cpu/encode_png.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "encode_jpeg.h"
#include "encode_png.h"

#include <torch/headeronly/util/Exception.h>

#include <torch/headeronly/util/Exception.h>

Expand All @@ -9,7 +11,9 @@ namespace image {

#if !PNG_FOUND

torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
torch::stable::Tensor encode_png(
const torch::stable::Tensor& data,
int64_t compression_level) {
STD_TORCH_CHECK(
false, "encode_png: torchvision not compiled with libpng support");
}
Expand Down Expand Up @@ -66,8 +70,9 @@ void torch_png_write_data(

} // namespace

torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");
torch::stable::Tensor encode_png(
const torch::stable::Tensor& data,
int64_t compression_level) {
// Define compression structures and error handling
png_structp png_write;
png_infop info_ptr;
Expand Down Expand Up @@ -104,12 +109,12 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
"Compression level should be between 0 and 9");

// Check that the input tensor is on CPU
STD_TORCH_CHECK(
data.device() == torch::kCPU, "Input tensor should be on CPU");
STD_TORCH_CHECK(data.is_cpu(), "Input tensor should be on CPU");

// Check that the input tensor dtype is uint8
STD_TORCH_CHECK(
data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
data.scalar_type() == torch::headeronly::ScalarType::Byte,
"Input tensor dtype should be uint8");

// Check that the input tensor is 3-dimensional
STD_TORCH_CHECK(
Expand All @@ -119,7 +124,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
int channels = data.size(0);
int height = data.size(1);
int width = data.size(2);
auto input = data.permute({1, 2, 0}).contiguous();
auto input = torch::stable::contiguous(stablePermute(data, {1, 2, 0}));

STD_TORCH_CHECK(
channels == 1 || channels == 3,
Expand Down Expand Up @@ -155,7 +160,7 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
png_write_info(png_write, info_ptr);

auto stride = width * channels;
auto ptr = input.data_ptr<uint8_t>();
auto ptr = input.const_data_ptr<uint8_t>();

// Encode PNG file
for (int y = 0; y < height; ++y) {
Expand All @@ -169,12 +174,13 @@ torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
// Destroy structures
png_destroy_write_struct(&png_write, &info_ptr);

torch::TensorOptions options = torch::TensorOptions{torch::kU8};
auto outTensor = torch::empty({(long)buf_info.size}, options);
int64_t out_size = static_cast<int64_t>(buf_info.size);
auto outTensor =
torch::stable::empty({&out_size, 1}, torch::headeronly::ScalarType::Byte);

// Copy memory from png buffer, since torch cannot get ownership of it via
// `from_blob`
auto outPtr = outTensor.data_ptr<uint8_t>();
auto outPtr = static_cast<uint8_t*>(outTensor.mutable_data_ptr());
std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
free(buf_info.buffer);

Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/io/image/cpu/encode_png.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#pragma once

#include <torch/types.h>
#include "../../../stable_abi_compat.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor encode_png(
const torch::Tensor& data,
torch::stable::Tensor encode_png(
const torch::stable::Tensor& data,
int64_t compression_level);

} // namespace image
Expand Down
28 changes: 28 additions & 0 deletions torchvision/csrc/io/image/cpu/exif.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ direct,

#include <torch/headeronly/util/Exception.h>
#include <torch/types.h>
#include "../../../stable_abi_compat.h"

namespace vision {
namespace image {
Expand Down Expand Up @@ -253,6 +254,33 @@ inline torch::Tensor exif_orientation_transform(
return image;
}

// Stable ABI version of exif_orientation_transform.
// Maps an EXIF orientation tag to the flip/transpose sequence that brings
// the decoded image back to the canonical top-left orientation. Unknown
// orientation values (including -1 for "no EXIF data") leave the image
// untouched.
inline torch::stable::Tensor exif_orientation_transform_stable(
    const torch::stable::Tensor& image,
    int orientation) {
  switch (orientation) {
    case IMAGE_ORIENTATION_TR:
      // mirrored horizontally
      return vision::stableFlip(image, {-1});
    case IMAGE_ORIENTATION_BR:
      // rotated 180 degrees == flip both horizontally and vertically
      return vision::stableFlip(image, {-2, -1});
    case IMAGE_ORIENTATION_BL:
      // mirrored vertically
      return vision::stableFlip(image, {-2});
    case IMAGE_ORIENTATION_LT:
      // transposed along the main diagonal
      return torch::stable::transpose(image, -1, -2);
    case IMAGE_ORIENTATION_RT:
      // rotated 90 degrees: transpose then horizontal flip
      return vision::stableFlip(torch::stable::transpose(image, -1, -2), {-1});
    case IMAGE_ORIENTATION_RB:
      // transposed along the anti-diagonal: transpose then flip both axes
      return vision::stableFlip(
          torch::stable::transpose(image, -1, -2), {-2, -1});
    case IMAGE_ORIENTATION_LB:
      // rotated 270 degrees: transpose then vertical flip
      return vision::stableFlip(torch::stable::transpose(image, -1, -2), {-2});
    case IMAGE_ORIENTATION_TL:
    default:
      // already upright, or unrecognized tag: return as-is
      return image;
  }
}

} // namespace exif_private
} // namespace image
} // namespace vision
17 changes: 14 additions & 3 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#include "image.h"

#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/stable/library.h>

namespace vision {
namespace image {

// Legacy registration for non-PNG ops
static auto registry =
torch::RegisterOperators()
.op("image::decode_gif", &decode_gif)
.op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
Expand All @@ -25,5 +24,17 @@ static auto registry =
.op("image::_jpeg_version", &_jpeg_version)
.op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo);

// Stable ABI registration for PNG ops.
// The schema strings must stay in sync with the legacy
// "image::decode_png(...)" registration elsewhere in this file so existing
// Python callers see an unchanged operator signature.
STABLE_TORCH_LIBRARY_FRAGMENT(image, m) {
  m.def(
      "decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor");
  m.def("encode_png(Tensor data, int compression_level) -> Tensor");
}

// CPU kernel bindings. TORCH_BOX adapts the unboxed C++ functions to the
// boxed stable calling convention expected by the dispatcher.
STABLE_TORCH_LIBRARY_IMPL(image, CPU, m) {
  m.impl("decode_png", TORCH_BOX(&decode_png));
  m.impl("encode_png", TORCH_BOX(&encode_png));
}

} // namespace image
} // namespace vision
89 changes: 89 additions & 0 deletions torchvision/csrc/stable_abi_compat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once

#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/DeviceType.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/core/TensorAccessor.h>
#include <torch/headeronly/util/Exception.h>

#include <array>
#include <string>
#include <vector>

// Conversion helpers between at::Tensor and torch::stable::Tensor.
// These are used at migration boundaries where some code is on the old API
// and some is on the stable ABI.
#include <ATen/Tensor.h>

namespace vision {

// Wraps an at::Tensor in a torch::stable::Tensor at a migration boundary.
// NOTE(review): this heap-allocates an at::Tensor and reinterprets the
// pointer as an AtenTensorHandle. It assumes (a) torch::stable::Tensor takes
// ownership of the handle it is constructed from, and (b) in this
// libtorch-linked build an AtenTensorHandle is represented as an at::Tensor*.
// Both hold for the current shim, but confirm before reusing this pattern
// elsewhere.
inline torch::stable::Tensor toStableTensor(at::Tensor t) {
  return torch::stable::Tensor(
      reinterpret_cast<AtenTensorHandle>(new at::Tensor(std::move(t))));
}

// Unwraps the at::Tensor behind a stable wrapper. Returns a copy of the
// at::Tensor handle (a refcount bump, not a data copy), so `t` keeps its
// ownership and the result is safe to use independently.
// NOTE(review): assumes the handle held by the stable Tensor is really an
// at::Tensor* — true only when linked against full libtorch; verify against
// the shim implementation.
inline at::Tensor fromStableTensor(const torch::stable::Tensor& t) {
  return *reinterpret_cast<at::Tensor*>(t.get());
}

// Dispatcher-based helpers for ops not yet in the stable ABI.

// Calls aten::permute through the boxed stable dispatcher.
// `dims` gives the new ordering of dimensions, as in Tensor::permute.
// The call is in-place on the StableIValue stack: arguments go in, and the
// dispatcher overwrites stack[0] with the result.
inline torch::stable::Tensor stablePermute(
    const torch::stable::Tensor& self,
    std::vector<int64_t> dims) {
  // Stack layout must match the aten::permute schema: (Tensor self, int[] dims).
  const auto num_args = 2;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self), torch::stable::detail::from(dims)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::permute", "", stack.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

// Calls aten::flip through the boxed stable dispatcher.
// `dims` lists the axes to reverse (negative indices allowed, as in
// Tensor::flip). The dispatcher overwrites stack[0] with the result.
inline torch::stable::Tensor stableFlip(
    const torch::stable::Tensor& self,
    std::vector<int64_t> dims) {
  // Stack layout must match the aten::flip schema: (Tensor self, int[] dims).
  const auto num_args = 2;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self), torch::stable::detail::from(dims)};
  TORCH_ERROR_CODE_CHECK(
      torch_call_dispatcher("aten::flip", "", stack.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

// Accessor helpers for torch::stable::Tensor, modeled after torchcodec's
// StableABICompat.h. These construct a HeaderOnlyTensorAccessor from the
// stable tensor's raw pointer, sizes, and strides. The accessor borrows the
// tensor's memory (and its sizes/strides storage), so the tensor must
// outlive any accessor created from it.
template <typename T, size_t N>
torch::headeronly::HeaderOnlyTensorAccessor<T, N> mutableAccessor(
    torch::stable::Tensor& tensor) {
  // at::Tensor::accessor<T, N> checks dim() == N; the header-only accessor
  // does not, so enforce it here to fail loudly on a rank mismatch instead
  // of reading sizes/strides out of bounds.
  STD_TORCH_CHECK(
      tensor.dim() == static_cast<int64_t>(N),
      "mutableAccessor: tensor rank does not match accessor rank");
  return torch::headeronly::HeaderOnlyTensorAccessor<T, N>(
      tensor.mutable_data_ptr<T>(),
      tensor.sizes().data(),
      tensor.strides().data());
}

// Read-only accessor over a stable tensor, modeled after
// at::Tensor::accessor<T, N>. The accessor borrows the tensor's data, sizes,
// and strides, so the tensor must outlive it.
template <typename T, size_t N>
torch::headeronly::HeaderOnlyTensorAccessor<const T, N> constAccessor(
    const torch::stable::Tensor& tensor) {
  // at::Tensor::accessor checks dim() == N; the header-only accessor does
  // not, so validate the rank before constructing the accessor.
  STD_TORCH_CHECK(
      tensor.dim() == static_cast<int64_t>(N),
      "constAccessor: tensor rank does not match accessor rank");
  return torch::headeronly::HeaderOnlyTensorAccessor<const T, N>(
      tensor.const_data_ptr<T>(),
      tensor.sizes().data(),
      tensor.strides().data());
}

// Stable ABI version of validate_encoded_data.
// Rejects anything that is not a contiguous, non-empty, 1-D uint8 tensor —
// the layout callers expect for an encoded image byte stream.
inline void validate_encoded_data_stable(
    const torch::stable::Tensor& encoded_data) {
  const bool contiguous = encoded_data.is_contiguous();
  STD_TORCH_CHECK(contiguous, "Input tensor must be contiguous.");

  const bool is_uint8 =
      encoded_data.scalar_type() == torch::headeronly::ScalarType::Byte;
  STD_TORCH_CHECK(is_uint8, "Input tensor must have uint8 data type.");

  const bool one_dim_nonempty =
      encoded_data.dim() == 1 && encoded_data.numel() > 0;
  STD_TORCH_CHECK(
      one_dim_nonempty, "Input tensor must be 1-dimensional and non-empty.");
}

} // namespace vision
Loading