diff --git a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp index 8163ace3307..af35eac7532 100644 --- a/torchvision/csrc/io/image/cpu/decode_jpeg.cpp +++ b/torchvision/csrc/io/image/cpu/decode_jpeg.cpp @@ -3,6 +3,8 @@ #include "common_jpeg.h" #include "exif.h" +#include + namespace vision { namespace image { @@ -78,8 +80,6 @@ static void torch_jpeg_set_source_mgr( inline unsigned char clamped_cmyk_rgb_convert( unsigned char k, unsigned char cmy) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 int v = k * cmy + 128; v = ((v >> 8) + v) >> 8; return std::clamp(k - v, 0, 255); @@ -103,8 +103,6 @@ void convert_line_cmyk_to_rgb( } inline unsigned char rgb_to_gray(int r, int g, int b) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; } @@ -127,29 +125,36 @@ void convert_line_cmyk_to_gray( } } -} // namespace +// Helper return type: keep longjmp/libjpeg contained here. +struct JpegDecodeResult { + torch::Tensor hwc; // HWC uint8 tensor + int exif_orientation = -1; // valid if apply_exif_orientation==true +}; -torch::Tensor decode_jpeg( +// Decode to HWC only. No permute/exif transform here (keep throwing ops outside longjmp zone). +static JpegDecodeResult decode_jpeg_hwc_impl( const torch::Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); + JpegDecodeResult res; validate_encoded_data(data); struct jpeg_decompress_struct cinfo; struct torch_jpeg_error_mgr jerr; + // NOTE: libjpeg uses setjmp/longjmp. longjmp does not unwind C++ stack frames. + // Any tensors allocated after setjmp must be explicitly reset in the error path. + c10::optional tensor_opt; + c10::optional cmyk_line_opt; + auto datap = data.data_ptr(); - // Setup decompression structure cinfo.err = jpeg_std_error(&jerr.pub); jerr.pub.error_exit = torch_jpeg_error_exit; - /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) { - /* If we get here, the JPEG code has signaled an error. - * We need to clean up the JPEG object. - */ + cmyk_line_opt.reset(); + tensor_opt.reset(); jpeg_destroy_decompress(&cinfo); STD_TORCH_CHECK(false, jerr.jpegLastErrorMsg); } @@ -157,7 +162,6 @@ torch::Tensor decode_jpeg( jpeg_create_decompress(&cinfo); torch_jpeg_set_source_mgr(&cinfo, datap, data.numel()); - // read info from header. jpeg_read_header(&cinfo, TRUE); int channels = cinfo.num_components; @@ -185,23 +189,15 @@ torch::Tensor decode_jpeg( } channels = 3; break; - /* - * Libjpeg does not support converting from CMYK to grayscale etc. There - * is a way to do this but it involves converting it manually to RGB: - * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 - */ default: jpeg_destroy_decompress(&cinfo); - STD_TORCH_CHECK( - false, "The provided mode is not supported for JPEG files"); + STD_TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); } - jpeg_calc_output_dimensions(&cinfo); } - int exif_orientation = -1; if (apply_exif_orientation) { - exif_orientation = fetch_jpeg_exif_orientation(&cinfo); + res.exif_orientation = fetch_jpeg_exif_orientation(&cinfo); } jpeg_start_decompress(&cinfo); @@ -210,21 +206,18 @@ torch::Tensor decode_jpeg( int width = cinfo.output_width; int stride = width * channels; - auto tensor = + + tensor_opt = torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.data_ptr(); - torch::Tensor cmyk_line_tensor; + auto ptr = tensor_opt->data_ptr(); + if (cmyk_to_rgb_or_gray) { - cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + cmyk_line_opt = torch::empty({int64_t(width), 4}, torch::kU8); } while (cinfo.output_scanline < cinfo.output_height) { - /* jpeg_read_scanlines expects an array of pointers to scanlines. - * Here the array is only one element long, but you could ask for - * more than one scanline at a time if that's more convenient. - */ if (cmyk_to_rgb_or_gray) { - auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + auto cmyk_line_ptr = cmyk_line_opt->data_ptr(); jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); if (channels == 3) { @@ -240,13 +233,31 @@ torch::Tensor decode_jpeg( jpeg_finish_decompress(&cinfo); jpeg_destroy_decompress(&cinfo); - auto output = tensor.permute({2, 0, 1}); + res.hwc = *tensor_opt; + return res; +} + +} // namespace + +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); + + // Longjmp/libjpeg zone is inside helper. + auto res = decode_jpeg_hwc_impl(data, mode, apply_exif_orientation); + + // Throwing ops are outside the longjmp zone. + auto output = res.hwc.permute({2, 0, 1}); if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); + return exif_orientation_transform(output, res.exif_orientation); } return output; } + #endif // #if !JPEG_FOUND int64_t _jpeg_version() { @@ -266,4 +277,4 @@ bool _is_compiled_against_turbo() { } } // namespace image -} // namespace vision +} // namespace vision \ No newline at end of file diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 67c788455c4..d4bae2f38c8 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -3,6 +3,10 @@ #include "common_png.h" #include "exif.h" +#include + +#include + namespace vision { namespace image { @@ -18,59 +22,109 @@ torch::Tensor decode_png( } #else +namespace { + bool is_little_endian() { uint32_t x = 1; return *(uint8_t*)&x; } -torch::Tensor decode_png( +struct TorchPngErrorContext { + char msg[256]; +}; + +static void torch_png_error_fn(png_structp png_ptr, png_const_charp error_msg) { + auto* err = + static_cast(png_get_error_ptr(png_ptr)); + if (err != nullptr) { + std::snprintf( + err->msg, + sizeof(err->msg), + "%s", + error_msg != nullptr ? error_msg : "Internal PNG error."); + } + longjmp(png_jmpbuf(png_ptr), 1); +} + +static void torch_png_warning_fn(png_structp, png_const_charp) { + // Keep default behavior quiet here. +} + +struct PngDecodeResult { + torch::Tensor hwc; // HWC uint8/uint16 tensor + int exif_orientation = -1; // valid if apply_exif_orientation == true +}; + +static PngDecodeResult decode_png_hwc_impl( const torch::Tensor& data, ImageReadMode mode, bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); + PngDecodeResult res; validate_encoded_data(data); - auto png_ptr = - png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + TorchPngErrorContext err{}; + std::snprintf(err.msg, sizeof(err.msg), "%s", "Internal PNG error."); + + auto png_ptr = png_create_read_struct( + PNG_LIBPNG_VER_STRING, + &err, + torch_png_error_fn, + torch_png_warning_fn); STD_TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + auto info_ptr = png_create_info_struct(png_ptr); if (!info_ptr) { png_destroy_read_struct(&png_ptr, nullptr, nullptr); - // Seems redundant with the if statement. done here to avoid leaking memory. - STD_TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + STD_TORCH_CHECK(false, "libpng info structure allocation failed!") + } + + c10::optional tensor_opt; + + if (setjmp(png_jmpbuf(png_ptr)) != 0) { + tensor_opt.reset(); + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + STD_TORCH_CHECK(false, err.msg); } auto accessor = data.accessor(); auto datap = accessor.data(); auto datap_len = accessor.size(0); - if (setjmp(png_jmpbuf(png_ptr)) != 0) { + if (datap_len < 8) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - STD_TORCH_CHECK(false, "Internal error."); + STD_TORCH_CHECK(false, "Content is too small for png!") } - STD_TORCH_CHECK(datap_len >= 8, "Content is too small for png!") + auto is_png = !png_sig_cmp(datap, 0, 8); - STD_TORCH_CHECK(is_png, "Content is not png!") + if (!is_png) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + STD_TORCH_CHECK(false, "Content is not png!") + } struct Reader { png_const_bytep ptr; png_size_t count; } reader; + reader.ptr = png_const_bytep(datap) + 8; reader.count = datap_len - 8; auto read_callback = [](png_structp png_ptr, png_bytep output, png_size_t bytes) { - auto reader = static_cast(png_get_io_ptr(png_ptr)); - STD_TORCH_CHECK( - reader->count >= bytes, - "Out of bound read in decode_png. Probably, the input image is corrupted"); + auto* reader = static_cast(png_get_io_ptr(png_ptr)); + if (reader->count < bytes) { + png_error( + png_ptr, + "Out of bound read in decode_png. Probably, the input image is corrupted"); + return; + } std::copy(reader->ptr, reader->ptr + bytes, output); reader->ptr += bytes; reader->count -= bytes; }; + png_set_sig_bytes(png_ptr, 8); png_set_read_fn(png_ptr, &reader, read_callback); png_read_info(png_ptr, info_ptr); @@ -78,6 +132,7 @@ torch::Tensor decode_png( png_uint_32 width, height; int bit_depth, color_type; int interlace_type; + auto retval = png_get_IHDR( png_ptr, info_ptr, @@ -91,7 +146,7 @@ torch::Tensor decode_png( if (retval != 1) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - STD_TORCH_CHECK(retval == 1, "Could read image metadata from content.") + STD_TORCH_CHECK(false, "Could not read image metadata from content.") } if (bit_depth > 8 && bit_depth != 16) { @@ -139,6 +194,7 @@ torch::Tensor decode_png( channels = 1; } break; + case IMAGE_READ_MODE_GRAY_ALPHA: if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { if (is_palette) { @@ -147,7 +203,8 @@ torch::Tensor decode_png( } if (!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + png_set_add_alpha( + png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); } if (has_color) { @@ -156,6 +213,7 @@ torch::Tensor decode_png( channels = 2; } break; + case IMAGE_READ_MODE_RGB: if (color_type != PNG_COLOR_TYPE_RGB) { if (is_palette) { @@ -171,6 +229,7 @@ torch::Tensor decode_png( channels = 3; } break; + case IMAGE_READ_MODE_RGB_ALPHA: if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { if (is_palette) { @@ -181,11 +240,13 @@ torch::Tensor decode_png( } if (!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + png_set_add_alpha( + png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); } channels = 4; } break; + default: png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); STD_TORCH_CHECK( @@ -197,35 +258,52 @@ torch::Tensor decode_png( auto num_pixels_per_row = width * channels; auto is_16_bits = bit_depth == 16; - auto tensor = torch::empty( + + tensor_opt = torch::empty( {int64_t(height), int64_t(width), channels}, is_16_bits ? at::kUInt16 : torch::kU8); + if (is_little_endian()) { png_set_swap(png_ptr); } - auto t_ptr = (uint8_t*)tensor.data_ptr(); + + auto t_ptr = reinterpret_cast(tensor_opt->data_ptr()); + for (int pass = 0; pass < number_of_passes; pass++) { for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, t_ptr, nullptr); t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); } - t_ptr = (uint8_t*)tensor.data_ptr(); + t_ptr = reinterpret_cast(tensor_opt->data_ptr()); } - int exif_orientation = -1; if (apply_exif_orientation) { - exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); + res.exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); } png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - auto output = tensor.permute({2, 0, 1}); + res.hwc = *tensor_opt; + return res; +} + +} // namespace + +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); + + auto res = decode_png_hwc_impl(data, mode, apply_exif_orientation); + + auto output = res.hwc.permute({2, 0, 1}); if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); + return exif_orientation_transform(output, res.exif_orientation); } return output; } #endif } // namespace image -} // namespace vision +} // namespace vision \ No newline at end of file