Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 47 additions & 36 deletions torchvision/csrc/io/image/cpu/decode_jpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "common_jpeg.h"
#include "exif.h"

#include <c10/util/Optional.h>

namespace vision {
namespace image {

Expand Down Expand Up @@ -78,8 +80,6 @@ static void torch_jpeg_set_source_mgr(
inline unsigned char clamped_cmyk_rgb_convert(
unsigned char k,
unsigned char cmy) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
int v = k * cmy + 128;
v = ((v >> 8) + v) >> 8;
return std::clamp(k - v, 0, 255);
Expand All @@ -103,8 +103,6 @@ void convert_line_cmyk_to_rgb(
}

inline unsigned char rgb_to_gray(int r, int g, int b) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
}

Expand All @@ -127,37 +125,43 @@ void convert_line_cmyk_to_gray(
}
}

} // namespace
// Helper return type: keep longjmp/libjpeg contained here.
struct JpegDecodeResult {
torch::Tensor hwc; // HWC uint8 tensor
int exif_orientation = -1; // valid if apply_exif_orientation==true
};

torch::Tensor decode_jpeg(
// Decode to HWC only. No permute/exif transform here (keep throwing ops outside longjmp zone).
static JpegDecodeResult decode_jpeg_hwc_impl(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
JpegDecodeResult res;

validate_encoded_data(data);

struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;

// NOTE: libjpeg uses setjmp/longjmp. longjmp does not unwind C++ stack frames.
// Any tensors allocated after setjmp must be explicitly reset in the error path.
c10::optional<torch::Tensor> tensor_opt;
c10::optional<torch::Tensor> cmyk_line_opt;

auto datap = data.data_ptr<uint8_t>();
// Setup decompression structure
cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;
/* Establish the setjmp return context for my_error_exit to use. */

if (setjmp(jerr.setjmp_buffer)) {
/* If we get here, the JPEG code has signaled an error.
* We need to clean up the JPEG object.
*/
cmyk_line_opt.reset();
tensor_opt.reset();
jpeg_destroy_decompress(&cinfo);
STD_TORCH_CHECK(false, jerr.jpegLastErrorMsg);
}

jpeg_create_decompress(&cinfo);
torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());

// read info from header.
jpeg_read_header(&cinfo, TRUE);

int channels = cinfo.num_components;
Expand Down Expand Up @@ -185,23 +189,15 @@ torch::Tensor decode_jpeg(
}
channels = 3;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*/
default:
jpeg_destroy_decompress(&cinfo);
STD_TORCH_CHECK(
false, "The provided mode is not supported for JPEG files");
STD_TORCH_CHECK(false, "The provided mode is not supported for JPEG files");
}

jpeg_calc_output_dimensions(&cinfo);
}

int exif_orientation = -1;
if (apply_exif_orientation) {
exif_orientation = fetch_jpeg_exif_orientation(&cinfo);
res.exif_orientation = fetch_jpeg_exif_orientation(&cinfo);
}

jpeg_start_decompress(&cinfo);
Expand All @@ -210,21 +206,18 @@ torch::Tensor decode_jpeg(
int width = cinfo.output_width;

int stride = width * channels;
auto tensor =

tensor_opt =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
auto ptr = tensor_opt->data_ptr<uint8_t>();

if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
cmyk_line_opt = torch::empty({int64_t(width), 4}, torch::kU8);
}

while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
auto cmyk_line_ptr = cmyk_line_opt->data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);

if (channels == 3) {
Expand All @@ -240,13 +233,31 @@ torch::Tensor decode_jpeg(

jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
auto output = tensor.permute({2, 0, 1});

res.hwc = *tensor_opt;
return res;
}

} // namespace

torch::Tensor decode_jpeg(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");

// Longjmp/libjpeg zone is inside helper.
auto res = decode_jpeg_hwc_impl(data, mode, apply_exif_orientation);

// Throwing ops are outside the longjmp zone.
auto output = res.hwc.permute({2, 0, 1});
if (apply_exif_orientation) {
return exif_orientation_transform(output, exif_orientation);
return exif_orientation_transform(output, res.exif_orientation);
}
return output;
}

#endif // #if !JPEG_FOUND

int64_t _jpeg_version() {
Expand All @@ -266,4 +277,4 @@ bool _is_compiled_against_turbo() {
}

} // namespace image
} // namespace vision
} // namespace vision
Loading