From f4fc92eee795606e9d4a1f5919f9c0f26bfb9e44 Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 8 Jan 2026 03:18:27 +0000 Subject: [PATCH 1/6] add rocjpeg support --- setup.py | 17 +- test/test_image.py | 16 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 308 +++++++++++++++++- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 32 ++ 4 files changed, 361 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 6181007924e..5ad744e5061 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other @@ -52,6 +53,7 @@ print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") +print(f"{USE_ROCJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{USE_CPU_VIDEO_DECODER = }") print(f"{USE_GPU_VIDEO_DECODER = }") @@ -350,18 +352,23 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() if nvjpeg_found: print("Building torchvision with NVJPEG image support") libraries.append("nvjpeg") define_macros += [("NVJPEG_FOUND", 1)] Extension = CUDAExtension + elif rocjpeg_found: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension else: - warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") + elif (USE_NVJPEG or USE_ROCJPEG): + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") return Extension( name="torchvision.image", diff --git a/test/test_image.py b/test/test_image.py index b11dd67ca12..e30b5695241 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -406,8 +406,10 @@ def test_read_interlaced_png(): @needs_cuda -@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize("scripted", (False, True)) +# @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("mode", [ImageReadMode.RGB]) +# @pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("scripted", (False, )) def test_decode_jpegs_cuda(mode, scripted): encoded_images = [] for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): @@ -415,15 +417,17 @@ def test_decode_jpegs_cuda(mode, scripted): continue encoded_image = read_file(jpeg_path) encoded_images.append(encoded_image) + encoded_images = encoded_images[:3] + # encoded_images = [encoded_images[0], encoded_images[2], encoded_images[1]] decoded_images_cpu = decode_jpeg(encoded_images, mode=mode) decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg # test multithreaded decoding # in the current version we prevent this by using a lock but we still want to test it - num_workers = 10 + num_workers = 1 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, "cuda:0") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: @@ -431,7 +435,9 @@ def test_decode_jpegs_cuda(mode, scripted): for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 + print(decoded_image_cuda.contiguous()) + print(decoded_image_cpu.contiguous().cpu()) + assert (decoded_image_cuda.contiguous().cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 5 @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 85aa6c760c1..9afb18abf32 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -1,5 +1,5 @@ #include "decode_jpegs_cuda.h" -#if !NVJPEG_FOUND +#if !NVJPEG_FOUND && !ROCJPEG_FOUND namespace vision { namespace image { std::vector decode_jpegs_cuda( @@ -11,8 +11,9 @@ std::vector decode_jpegs_cuda( } } // namespace image } // namespace vision +#endif -#else +#if NVJPEG_FOUND #include #include #include @@ -600,3 +601,306 @@ std::vector CUDAJpegDecoder::decode_images( } // namespace vision #endif + +#if ROCJPEG_FOUND + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vision { +namespace image { + +std::mutex decoderMutex; +std::unique_ptr rocJpegDecoder; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + + std::lock_guard lock(decoderMutex); + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + + TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + for (auto& encoded_image : encoded_images) { + TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !encoded_image.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + // nvjpeg requires images to be contiguous + if (encoded_image.is_contiguous()) { + contig_images.push_back(encoded_image); + } else { + contig_images.push_back(encoded_image.contiguous()); + } + } + + at::cuda::CUDAGuard device_guard(device); + + if (rocJpegDecoder == nullptr || device != rocJpegDecoder->target_device) { + if (rocJpegDecoder != nullptr) { + rocJpegDecoder.reset(new RocJpegDecoder(device)); + } else { + rocJpegDecoder = std::make_unique(device); + std::atexit([]() { rocJpegDecoder.reset(); }); + } + } + + RocJpegOutputFormat output_format; + + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + output_format = ROCJPEG_OUTPUT_NATIVE; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = ROCJPEG_OUTPUT_Y; + break; + case vision::image::IMAGE_READ_MODE_RGB: + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + break; + default: + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + try { + at::cuda::CUDAEvent event; + auto result = rocJpegDecoder->decode_images(contig_images, output_format); + auto current_stream{ + device.has_index() ? at::cuda::getCurrentCUDAStream( + rocJpegDecoder->original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(rocJpegDecoder->stream); + event.block(current_stream); + return result; + } catch (const std::exception& e) { + if (typeid(e) != typeid(std::runtime_error)) { + TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } else { + throw; + } + } +} + +RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) + : original_device{torch::kCUDA, c10::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + int device_id = target_device.index(); + CHECK_HIP(hipSetDevice(device_id)); + RocJpegStatus status; + RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; + + status = rocJpegCreate(rocjpeg_backend, device_id, &rocjpeg_handle); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, + "Failed to initialize rocjpeg with hardware backend"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[0]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[1]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle); + rocJpegStreamDestroy(rocjpeg_stream_handles[0]); + rocJpegStreamDestroy(rocjpeg_stream_handles[1]); +} + +static inline int align(int value, int alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format) { + /* + This function decodes a batch of jpeg bitstreams. + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB, ROCJPEG_OUTPUT_Y + or ROCJPEG_OUTPUT_NATIVE + - device (torch::Device): The desired CUDA device for the returned Tensors + + Returns: + - output_tensors (std::vector): a vector of Tensors + containing the decoded images + */ + + int num_images = encoded_images.size(); + std::vector output_tensors{num_images}; + RocJpegStatus rocjpeg_status; + cudaError_t cudaStatus; + + // baseline JPEGs can be batch decoded with hardware support + std::vector channels(num_images); + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + constexpr int batch_size = 2; + RocJpegUtils rocjpeg_utils; + std::string chroma_sub_sampling = ""; + uint8_t num_components; + RocJpegChromaSubsampling temp_subsampling; + std::vector temp_widths(ROCJPEG_MAX_COMPONENT, 0); + std::vector temp_heights(ROCJPEG_MAX_COMPONENT, 0); + RocJpegDecodeParams decode_params = {}; + decode_params.output_format = output_format; + std::vector decode_params_batch; + decode_params_batch.resize(batch_size, decode_params); + std::vector output_images; + output_images.resize(batch_size); + int current_batch_size = 0; + uint32_t channel_sizes[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t num_channels = 0; + std::vector> prior_channel_sizes; + prior_channel_sizes.resize( + batch_size, std::vector(ROCJPEG_MAX_COMPONENT, 0)); + + for (int i = 0; i < num_images; i += batch_size) { + int batch_end = std::min(i + batch_size, num_images); + for (int j = i; j < batch_end; j++) { + int index = j - i; + rocjpeg_status = rocJpegStreamParse( + (unsigned char*)encoded_images[j].data_ptr(), + encoded_images[j].numel(), + rocjpeg_stream_handles[index]); + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { + TORCH_CHECK( + false, + "ERROR: Failed to parse the input jpeg stream with ", + rocJpegGetErrorName(rocjpeg_status)); + } + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle, + rocjpeg_stream_handles[index], + &num_components, + &temp_subsampling, + temp_widths.data(), + temp_heights.data())); + rocjpeg_utils.GetChromaSubsamplingStr( + temp_subsampling, chroma_sub_sampling); + if (temp_widths[0] < 64 || temp_heights[0] < 64) { + TORCH_CHECK( + false, "The image resolution is not supported by VCN Hardware"); + } + if (temp_subsampling == ROCJPEG_CSS_411 || + temp_subsampling == ROCJPEG_CSS_UNKNOWN) { + TORCH_CHECK( + false, "The chroma sub-sampling is not supported by VCN Hardware"); + } + if (rocjpeg_utils.GetChannelPitchAndSizes( + decode_params_batch[index], + temp_subsampling, + temp_widths.data(), + temp_heights.data(), + num_channels, + output_images[index], + channel_sizes)) { + TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); + } + + uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; + uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; + bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; + std::cout << "is_roi_valid: " << is_roi_valid << "\n"; + uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); + uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); + auto output_tensor = torch::zeros( + {int64_t(num_channels), + int64_t(height), + int64_t(width)}, + torch::dtype(torch::kU8).device(target_device)); + channels[j] = num_channels; + + // for (int n = 0; n < (int)num_channels; n++) { + // output_images[current_batch_size].channel[n] = + // output_tensor[n].data_ptr(); + // } + + // allocate memory for each channel and reuse them if the sizes remain + // unchanged for a new image. + for (int c = 0; c < (int)num_channels; c++) { + output_images[index].channel[c] = output_tensor[c].data_ptr(); + } + // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { + // output_images[index].channel[c] = NULL; + // output_images[index].pitch[c] = 0; + // } + // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + current_batch_size++; + output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + } + + // if (current_batch_size == 2) { + if (current_batch_size > 0) { + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle, + rocjpeg_stream_handles, + current_batch_size, + decode_params_batch.data(), + output_images.data())); + } + + current_batch_size = 0; + } + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // prune extraneous channels from single channel images + if (output_format == ROCJPEG_OUTPUT_NATIVE) { + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (channels[i] == 1) { + output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } + } + } + + cudaDeviceSynchronize(); + return output_tensors; +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..a052e5e354e 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -4,6 +4,7 @@ #include "../common.h" #if NVJPEG_FOUND + #include #include @@ -42,4 +43,35 @@ class CUDAJpegDecoder { }; } // namespace image } // namespace vision + +#endif + +#if ROCJPEG_FOUND + +#include +#include +#include "rocjpeg_samples_utils.h" + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + private: + RocJpegStreamHandle rocjpeg_stream_handles[2]; + RocJpegHandle rocjpeg_handle; +}; +} // namespace image +} // namespace vision + #endif From a371c3e57b4c838f3a707dc08e91c90d6e969f5c Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 8 Jan 2026 03:21:08 +0000 Subject: [PATCH 2/6] update rocjpeg utils --- .../io/image/cuda/rocjpeg_samples_utils.h | 567 ++++++++++++++++++ 1 file changed, 567 insertions(+) create mode 100644 torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h diff --git a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h new file mode 100644 index 00000000000..3a9595dcc4f --- /dev/null +++ b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h @@ -0,0 +1,567 @@ +/* +Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef ROC_JPEG_SAMPLES_COMMON +#define ROC_JPEG_SAMPLES_COMMON +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if __cplusplus >= 201703L && __has_include() + #include + namespace fs = std::filesystem; +#else + #include + namespace fs = std::experimental::filesystem; +#endif +#include +#include "rocjpeg/rocjpeg.h" + +#define CHECK_ROCJPEG(call) { \ + RocJpegStatus rocjpeg_status = (call); \ + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ + std::cerr << #call << " returned " << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ + exit(1); \ + } \ +} + +#define CHECK_HIP(call) { \ + hipError_t hip_status = (call); \ + if (hip_status != hipSuccess) { \ + std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) << "' at " << __FILE__ << ":" << __LINE__ << std::endl;\ + exit(1); \ + } \ +} + +/** + * @class RocJpegUtils + * @brief Utility class for rocJPEG samples. + * + * This class provides utility functions for rocJPEG samples, such as parsing command line arguments, + * getting file paths, initializing HIP device, getting chroma subsampling string, getting channel pitch and sizes, + * getting output file extension, and saving images. + */ +class RocJpegUtils { +public: + /** + * @brief Parses the command line arguments. + * + * This function parses the command line arguments and sets the corresponding variables. + * + * @param input_path The input path. + * @param output_file_path The output file path. + * @param save_images Flag indicating whether to save images. + * @param device_id The device ID. + * @param rocjpeg_backend The rocJPEG backend. + * @param decode_params The rocJPEG decode parameters. + * @param num_threads The number of threads. + * @param crop The crop rectangle. + * @param argc The number of command line arguments. + * @param argv The command line arguments. + */ + static void ParseCommandLine(std::string &input_path, std::string &output_file_path, bool &save_images, int &device_id, + RocJpegBackend &rocjpeg_backend, RocJpegDecodeParams &decode_params, int *num_threads, int *batch_size, int argc, char *argv[]) { + if(argc <= 1) { + ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); + } + for (int i = 1; i < argc; i++) { + if (!strcmp(argv[i], "-h")) { + ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); + } + if (!strcmp(argv[i], "-i")) { + if (++i == argc) { + ShowHelpAndExit("-i", num_threads != nullptr, batch_size != nullptr); + } + input_path = argv[i]; + continue; + } + if (!strcmp(argv[i], "-o")) { + if (++i == argc) { + ShowHelpAndExit("-o", num_threads != nullptr, batch_size != nullptr); + } + output_file_path = argv[i]; + save_images = true; + continue; + } + if (!strcmp(argv[i], "-d")) { + if (++i == argc) { + ShowHelpAndExit("-d", num_threads != nullptr, batch_size != nullptr); + } + device_id = atoi(argv[i]); + continue; + } + if (!strcmp(argv[i], "-be")) { + if (++i == argc) { + ShowHelpAndExit("-be", num_threads != nullptr, batch_size != nullptr); + } + rocjpeg_backend = static_cast(atoi(argv[i])); + continue; + } + if (!strcmp(argv[i], "-fmt")) { + if (++i == argc) { + ShowHelpAndExit("-fmt", num_threads != nullptr, batch_size != nullptr); + } + std::string selected_output_format = argv[i]; + if (selected_output_format == "native") { + decode_params.output_format = ROCJPEG_OUTPUT_NATIVE; + } else if (selected_output_format == "yuv_planar") { + decode_params.output_format = ROCJPEG_OUTPUT_YUV_PLANAR; + } else if (selected_output_format == "y") { + decode_params.output_format = ROCJPEG_OUTPUT_Y; + } else if (selected_output_format == "rgb") { + decode_params.output_format = ROCJPEG_OUTPUT_RGB; + } else if (selected_output_format == "rgb_planar") { + decode_params.output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + } else { + ShowHelpAndExit(argv[i], num_threads != nullptr); + } + continue; + } + if (!strcmp(argv[i], "-t")) { + if (++i == argc) { + ShowHelpAndExit("-t", num_threads != nullptr, batch_size != nullptr); + } + if (num_threads != nullptr) { + *num_threads = atoi(argv[i]); + if (*num_threads <= 0 || *num_threads > 32) { + ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); + } + } + continue; + } + if (!strcmp(argv[i], "-b")) { + if (++i == argc) { + ShowHelpAndExit("-b", num_threads != nullptr, batch_size != nullptr); + } + if (batch_size != nullptr) + *batch_size = atoi(argv[i]); + continue; + } + if (!strcmp(argv[i], "-crop")) { + if (++i == argc || 4 != sscanf(argv[i], "%hd,%hd,%hd,%hd", &decode_params.crop_rectangle.left, &decode_params.crop_rectangle.top, &decode_params.crop_rectangle.right, &decode_params.crop_rectangle.bottom)) { + ShowHelpAndExit("-crop"); + } + if ((&decode_params.crop_rectangle.right - &decode_params.crop_rectangle.left) % 2 == 1 || (&decode_params.crop_rectangle.bottom - &decode_params.crop_rectangle.top) % 2 == 1) { + std::cout << "output crop rectangle must have width and height of even numbers" << std::endl; + exit(1); + } + continue; + } + ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); + } + } + + /** + * Checks if a file is a JPEG file. + * + * @param filePath The path to the file to be checked. + * @return True if the file is a JPEG file, false otherwise. + */ + static bool IsJPEG(const std::string& filePath) { + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + std::cerr << "Failed to open file: " << filePath << std::endl; + return false; + } + + unsigned char buffer[2]; + file.read(reinterpret_cast(buffer), 2); + file.close(); + + // The first two bytes of every JPEG stream are always 0xFFD8, which represents the Start of Image (SOI) marker. + return buffer[0] == 0xFF && buffer[1] == 0xD8; + } + + /** + * @brief Gets the file paths. + * + * This function gets the file paths based on the input path and sets the corresponding variables. + * + * @param input_path The input path. + * @param file_paths The vector to store the file paths. + * @param is_dir Flag indicating whether the input path is a directory. + * @param is_file Flag indicating whether the input path is a file. + * @return True if successful, false otherwise. + */ + static bool GetFilePaths(std::string &input_path, std::vector &file_paths, bool &is_dir, bool &is_file) { + std::cout << "Reading images from disk, please wait!" << std::endl; + if (!fs::exists(input_path)) { + std::cerr << "ERROR: the input path does not exist!" << std::endl; + return false; + } + is_dir = fs::is_directory(input_path); + is_file = fs::is_regular_file(input_path); + if (is_dir) { + for (const auto &entry : fs::recursive_directory_iterator(input_path)) { + if (fs::is_regular_file(entry) && IsJPEG(entry.path().string())) { + file_paths.push_back(entry.path().string()); + } + } + } else if (is_file && IsJPEG(input_path)) { + file_paths.push_back(input_path); + } else { + std::cerr << "ERROR: the input path does not contain JPEG files!" << std::endl; + return false; + } + return true; + } + + /** + * @brief Initializes the HIP device. + * + * This function initializes the HIP device with the specified device ID. + * + * @param device_id The device ID. + * @return True if successful, false otherwise. + */ + static bool InitHipDevice(int device_id) { + int num_devices; + hipDeviceProp_t hip_dev_prop; + CHECK_HIP(hipGetDeviceCount(&num_devices)); + if (num_devices < 1) { + std::cerr << "ERROR: didn't find any GPU!" << std::endl; + return false; + } + if (device_id >= num_devices) { + std::cerr << "ERROR: the requested device_id is not found!" << std::endl; + return false; + } + CHECK_HIP(hipSetDevice(device_id)); + CHECK_HIP(hipGetDeviceProperties(&hip_dev_prop, device_id)); + + std::cout << "Using GPU device " << device_id << ": " << hip_dev_prop.name << "[" << hip_dev_prop.gcnArchName << "] on PCI bus " << + std::setfill('0') << std::setw(2) << std::right << std::hex << hip_dev_prop.pciBusID << ":" << std::setfill('0') << std::setw(2) << + std::right << std::hex << hip_dev_prop.pciDomainID << "." << hip_dev_prop.pciDeviceID << std::dec << std::endl; + + return true; + } + + /** + * @brief Gets the chroma subsampling string. + * + * This function gets the chroma subsampling string based on the specified subsampling value. + * + * @param subsampling The chroma subsampling value. + * @param chroma_sub_sampling The string to store the chroma subsampling. + */ + void GetChromaSubsamplingStr(RocJpegChromaSubsampling subsampling, std::string &chroma_sub_sampling) { + switch (subsampling) { + case ROCJPEG_CSS_444: + chroma_sub_sampling = "YUV 4:4:4"; + break; + case ROCJPEG_CSS_440: + chroma_sub_sampling = "YUV 4:4:0"; + break; + case ROCJPEG_CSS_422: + chroma_sub_sampling = "YUV 4:2:2"; + break; + case ROCJPEG_CSS_420: + chroma_sub_sampling = "YUV 4:2:0"; + break; + case ROCJPEG_CSS_411: + chroma_sub_sampling = "YUV 4:1:1"; + break; + case ROCJPEG_CSS_400: + chroma_sub_sampling = "YUV 4:0:0"; + break; + case ROCJPEG_CSS_UNKNOWN: + chroma_sub_sampling = "UNKNOWN"; + break; + default: + chroma_sub_sampling = ""; + break; + } + } + + /** + * @brief Gets the channel pitch and sizes. + * + * This function gets the channel pitch and sizes based on the specified output format, chroma subsampling, + * output image, and channel sizes. + * + * @param decode_params The decode parameters that specify the output format and crop rectangle. + * @param subsampling The chroma subsampling. + * @param widths The array to store the channel widths. + * @param heights The array to store the channel heights. + * @param num_channels The number of channels. + * @param output_image The output image. + * @param channel_sizes The array to store the channel sizes. + * @return The channel pitch. + */ + int GetChannelPitchAndSizes(RocJpegDecodeParams decode_params, RocJpegChromaSubsampling subsampling, uint32_t *widths, uint32_t *heights, + uint32_t &num_channels, RocJpegImage &output_image, uint32_t *channel_sizes) { + + bool is_roi_valid = false; + uint32_t roi_width; + uint32_t roi_height; + roi_width = decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; + roi_height = decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; + if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && roi_height <= heights[0]) { + is_roi_valid = true; + } + switch (decode_params.output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_440: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_422: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 2; + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_420: + num_channels = 2; + output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return EXIT_FAILURE; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + if (subsampling == ROCJPEG_CSS_400) { + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + } else { + num_channels = 3; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[1], mem_alignment); + output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[2], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[1], mem_alignment)); + channel_sizes[2] = output_image.pitch[2] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[2], mem_alignment)); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 3; + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown output format!" << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; + } + + /** + * @brief Gets the output file extension. + * + * This function gets the output file extension based on the specified output format, base file name, + * image width, image height, and file name for saving. + * + * @param output_format The output format. + * @param base_file_name The base file name. + * @param image_width The image width. + * @param image_height The image height. + * @param file_name_for_saving The string to store the file name for saving. + */ + void GetOutputFileExt(RocJpegOutputFormat output_format, std::string &base_file_name, uint32_t image_width, uint32_t image_height, RocJpegChromaSubsampling subsampling, std::string &file_name_for_saving) { + std::string file_extension; + std::string::size_type const p(base_file_name.find_last_of('.')); + std::string file_name_no_ext = base_file_name.substr(0, p); + std::string format_description = ""; + switch (output_format) { + case ROCJPEG_OUTPUT_NATIVE: + file_extension = "yuv"; + switch (subsampling) { + case ROCJPEG_CSS_444: + format_description = "444"; + break; + case ROCJPEG_CSS_440: + format_description = "440"; + break; + case ROCJPEG_CSS_422: + format_description = "422_yuyv"; + break; + case ROCJPEG_CSS_420: + format_description = "nv12"; + break; + case ROCJPEG_CSS_400: + format_description = "400"; + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + file_extension = "yuv"; + format_description = "planar"; + break; + case ROCJPEG_OUTPUT_Y: + file_extension = "yuv"; + format_description = "400"; + break; + case ROCJPEG_OUTPUT_RGB: + file_extension = "rgb"; + format_description = "packed"; + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + file_extension = "rgb"; + format_description = "planar"; + break; + default: + file_extension = ""; + break; + } + file_name_for_saving += "//" + file_name_no_ext + "_" + std::to_string(image_width) + "x" + + std::to_string(image_height) + "_" + format_description + "." + file_extension; + } + +private: + static const int mem_alignment = 16; + /** + * @brief Shows the help message and exits. + * + * This function shows the help message and exits the program. + * + * @param option The option to display in the help message (optional). + * @param show_threads Flag indicating whether to show the number of threads in the help message. + */ + static void ShowHelpAndExit(const char *option = nullptr, bool show_threads = false, bool show_batch_size = false) { + std::cout << "Options:\n" + "-i [input path] - input path to a single JPEG image or a directory containing JPEG images - [required]\n" + "-be [backend] - select rocJPEG backend (0 for hardware-accelerated JPEG decoding using VCN,\n" + " 1 for hybrid JPEG decoding using CPU and GPU HIP kernels (currently not supported)) [optional - default: 0]\n" + "-fmt [output format] - select rocJPEG output format for decoding, one of the [native, yuv_planar, y, rgb, rgb_planar] - [optional - default: native]\n" + "-o [output path] - path to an output file or a path to an existing directory - write decoded images to a file or an existing directory based on selected output format - [optional]\n" + "-crop [crop rectangle] - crop rectangle for output in a comma-separated format: left,top,right,bottom - [optional]\n" + "-d [device id] - specify the GPU device id for the desired device (use 0 for the first device, 1 for the second device, and so on) [optional - default: 0]\n"; + if (show_threads) { + std::cout << "-t [threads] - number of threads (<= 32) for parallel JPEG decoding - [optional - default: 1]\n"; + } + if (show_batch_size) { + std::cout << "-b [batch_size] - decode images from input by batches of a specified size - [optional - default: 1]\n"; + } + exit(0); + } + /** + * @brief Aligns a value to a specified alignment. + * + * This function takes a value and aligns it to the specified alignment. It returns the aligned value. + * + * @param value The value to be aligned. + * @param alignment The alignment value. + * @return The aligned value. + */ + static inline int align(int value, int alignment) { + return (value + alignment - 1) & ~(alignment - 1); + } +}; + +class ThreadPool { + public: + ThreadPool(int nthreads) : shutdown_(false) { + // Create the specified number of threads + threads_.reserve(nthreads); + for (int i = 0; i < nthreads; ++i) + threads_.emplace_back(std::bind(&ThreadPool::ThreadEntry, this, i)); + } + + ~ThreadPool() {} + + void JoinThreads() { + { + // Unblock any threads and tell them to stop + std::unique_lock lock(mutex_); + shutdown_ = true; + cond_var_.notify_all(); + } + + // Wait for all threads to stop + for (auto& thread : threads_) + thread.join(); + } + + void ExecuteJob(std::function func) { + // Place a job on the queue and unblock a thread + std::unique_lock lock(mutex_); + decode_jobs_queue_.emplace(std::move(func)); + cond_var_.notify_one(); + } + + protected: + void ThreadEntry(int i) { + std::function execute_decode_job; + + while (true) { + { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [&] {return shutdown_ || !decode_jobs_queue_.empty();}); + if (decode_jobs_queue_.empty()) { + // No jobs to do; shutting down + return; + } + + execute_decode_job = std::move(decode_jobs_queue_.front()); + decode_jobs_queue_.pop(); + } + + // Execute the decode job without holding any locks + execute_decode_job(); + } + } + + std::mutex mutex_; + std::condition_variable cond_var_; + bool shutdown_; + std::queue> decode_jobs_queue_; + std::vector threads_; +}; + +#endif //ROC_JPEG_SAMPLES_COMMON From e4c4fd0f1b9eeb5218042437ea15d7cb537abff8 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 8 Jan 2026 11:29:14 +0800 Subject: [PATCH 3/6] rm cout --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 9afb18abf32..d0c2f2cdacf 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -838,7 +838,6 @@ std::vector RocJpegDecoder::decode_images( uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; - std::cout << "is_roi_valid: " << is_roi_valid << "\n"; uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); auto output_tensor = torch::zeros( From 3d9041c4012f2e265a8da057fc7655b6680c2aed Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 09:53:44 +0000 Subject: [PATCH 4/6] refine code --- setup.py | 2 +- test/test_image.py | 16 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 233 +++++++++++++++--- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 22 +- 4 files changed, 232 insertions(+), 41 deletions(-) diff --git a/setup.py b/setup.py index 5ad744e5061..4b9559eb630 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" -USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "0") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other diff --git a/test/test_image.py b/test/test_image.py index e30b5695241..b11dd67ca12 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -406,10 +406,8 @@ def test_read_interlaced_png(): @needs_cuda -# @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize("mode", [ImageReadMode.RGB]) -# @pytest.mark.parametrize("scripted", (False, True)) -@pytest.mark.parametrize("scripted", (False, )) +@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("scripted", (False, True)) def test_decode_jpegs_cuda(mode, scripted): encoded_images = [] for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): @@ -417,17 +415,15 @@ def test_decode_jpegs_cuda(mode, scripted): continue encoded_image = read_file(jpeg_path) encoded_images.append(encoded_image) - encoded_images = encoded_images[:3] - # encoded_images = [encoded_images[0], encoded_images[2], encoded_images[1]] decoded_images_cpu = decode_jpeg(encoded_images, mode=mode) decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg # test multithreaded decoding # in the current version we prevent this by using a lock but we still want to test it - num_workers = 1 + num_workers = 10 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, "cuda:0") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: @@ -435,9 +431,7 @@ def test_decode_jpegs_cuda(mode, scripted): for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - print(decoded_image_cuda.contiguous()) - print(decoded_image_cpu.contiguous().cpu()) - assert (decoded_image_cuda.contiguous().cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 5 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index d0c2f2cdacf..9b974b7dc03 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -671,18 +671,13 @@ std::vector decode_jpegs_cuda( RocJpegOutputFormat output_format; switch (mode) { - case vision::image::IMAGE_READ_MODE_UNCHANGED: - output_format = ROCJPEG_OUTPUT_NATIVE; - break; - case vision::image::IMAGE_READ_MODE_GRAY: - output_format = ROCJPEG_OUTPUT_Y; - break; case vision::image::IMAGE_READ_MODE_RGB: output_format = ROCJPEG_OUTPUT_RGB_PLANAR; break; default: TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); + false, + "The provided mode is not supported for ROCJPEG decoding on GPU"); } try { @@ -736,8 +731,184 @@ RocJpegDecoder::~RocJpegDecoder() { rocJpegStreamDestroy(rocjpeg_stream_handles[1]); } +static constexpr int mem_alignment = 16; + static inline int align(int value, int alignment) { - return (value + alignment - 1) & ~(alignment - 1); + return (value + alignment - 1) & ~(alignment - 1); +} + +void getChromaSubsamplingStr( + RocJpegChromaSubsampling subsampling, + std::string& chroma_sub_sampling) { + switch (subsampling) { + case ROCJPEG_CSS_444: + chroma_sub_sampling = "YUV 4:4:4"; + break; + case ROCJPEG_CSS_440: + chroma_sub_sampling = "YUV 4:4:0"; + break; + case ROCJPEG_CSS_422: + chroma_sub_sampling = "YUV 4:2:2"; + break; + case ROCJPEG_CSS_420: + chroma_sub_sampling = "YUV 4:2:0"; + break; + case ROCJPEG_CSS_411: + chroma_sub_sampling = "YUV 4:1:1"; + break; + case ROCJPEG_CSS_400: + chroma_sub_sampling = "YUV 4:0:0"; + break; + case ROCJPEG_CSS_UNKNOWN: + chroma_sub_sampling = "UNKNOWN"; + break; + default: + chroma_sub_sampling = ""; + break; + } +} + +int getChannelPitchAndSizes( + RocJpegDecodeParams decode_params, + RocJpegChromaSubsampling subsampling, + uint32_t* widths, + uint32_t* heights, + uint32_t& num_channels, + RocJpegImage& output_image, + uint32_t* channel_sizes) { + bool is_roi_valid = false; + uint32_t roi_width; + uint32_t roi_height; + roi_width = + decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; + roi_height = + decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; + if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && + roi_height <= heights[0]) { + is_roi_valid = true; + } + switch (decode_params.output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_440: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_422: + num_channels = 1; + output_image.pitch[0] = + (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 2; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_420: + num_channels = 2; + output_image.pitch[1] = output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return EXIT_FAILURE; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + if (subsampling == ROCJPEG_CSS_400) { + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + } else { + num_channels = 3; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[1], mem_alignment); + output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[2], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[1], mem_alignment)); + channel_sizes[2] = output_image.pitch[2] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[2], mem_alignment)); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 3; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown output format!" << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; } std::vector RocJpegDecoder::decode_images( @@ -757,7 +928,7 @@ std::vector RocJpegDecoder::decode_images( - output_tensors (std::vector): a vector of Tensors containing the decoded images */ - + int num_images = encoded_images.size(); std::vector output_tensors{num_images}; RocJpegStatus rocjpeg_status; @@ -773,7 +944,6 @@ std::vector RocJpegDecoder::decode_images( cudaStatus); constexpr int batch_size = 2; - RocJpegUtils rocjpeg_utils; std::string chroma_sub_sampling = ""; uint8_t num_components; RocJpegChromaSubsampling temp_subsampling; @@ -797,9 +967,9 @@ std::vector RocJpegDecoder::decode_images( for (int j = i; j < batch_end; j++) { int index = j - i; rocjpeg_status = rocJpegStreamParse( - (unsigned char*)encoded_images[j].data_ptr(), - encoded_images[j].numel(), - rocjpeg_stream_handles[index]); + (unsigned char*)encoded_images[j].data_ptr(), + encoded_images[j].numel(), + rocjpeg_stream_handles[index]); if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { TORCH_CHECK( false, @@ -813,8 +983,7 @@ std::vector RocJpegDecoder::decode_images( &temp_subsampling, temp_widths.data(), temp_heights.data())); - rocjpeg_utils.GetChromaSubsamplingStr( - temp_subsampling, chroma_sub_sampling); + getChromaSubsamplingStr(temp_subsampling, chroma_sub_sampling); if (temp_widths[0] < 64 || temp_heights[0] < 64) { TORCH_CHECK( false, "The image resolution is not supported by VCN Hardware"); @@ -824,7 +993,7 @@ std::vector RocJpegDecoder::decode_images( TORCH_CHECK( false, "The chroma sub-sampling is not supported by VCN Hardware"); } - if (rocjpeg_utils.GetChannelPitchAndSizes( + if (getChannelPitchAndSizes( decode_params_batch[index], temp_subsampling, temp_widths.data(), @@ -835,15 +1004,21 @@ std::vector RocJpegDecoder::decode_images( TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); } - uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; - uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; - bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; - uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); - uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); + uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - + decode_params_batch[index].crop_rectangle.left; + uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - + decode_params_batch[index].crop_rectangle.top; + bool is_roi_valid = + (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && + roi_height <= temp_heights[0]) + ? true + : false; + uint32_t width = is_roi_valid ? align(roi_width, mem_alignment) + : align(temp_widths[0], mem_alignment); + uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) + : align(temp_heights[0], mem_alignment); auto output_tensor = torch::zeros( - {int64_t(num_channels), - int64_t(height), - int64_t(width)}, + {int64_t(num_channels), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(target_device)); channels[j] = num_channels; @@ -855,15 +1030,17 @@ std::vector RocJpegDecoder::decode_images( // allocate memory for each channel and reuse them if the sizes remain // unchanged for a new image. for (int c = 0; c < (int)num_channels; c++) { - output_images[index].channel[c] = output_tensor[c].data_ptr(); + output_images[index].channel[c] = output_tensor[c].data_ptr(); } // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { // output_images[index].channel[c] = NULL; // output_images[index].pitch[c] = 0; // } - // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, + // temp_heights[0]).narrow(2, 0, temp_widths[0]); current_batch_size++; - output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) + .narrow(2, 0, temp_widths[0]); } // if (current_batch_size == 2) { @@ -874,7 +1051,7 @@ std::vector RocJpegDecoder::decode_images( current_batch_size, decode_params_batch.data(), output_images.data())); - } + } current_batch_size = 0; } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index a052e5e354e..5c0fa56113b 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -50,7 +50,6 @@ class CUDAJpegDecoder { #include #include -#include "rocjpeg_samples_utils.h" namespace vision { namespace image { @@ -74,4 +73,25 @@ class RocJpegDecoder { } // namespace image } // namespace vision +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ + std::cerr << #call << " returned " \ + << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ \ + << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + if (hip_status != hipSuccess) { \ + std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) \ + << "' at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + #endif From 1d299860339b60ca1d5fd488e4149e429bcd33f6 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 09:57:41 +0000 Subject: [PATCH 5/6] rm unused file --- .../io/image/cuda/rocjpeg_samples_utils.h | 567 ------------------ 1 file changed, 567 deletions(-) delete mode 100644 torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h diff --git a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h deleted file mode 100644 index 3a9595dcc4f..00000000000 --- a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h +++ /dev/null @@ -1,567 +0,0 @@ -/* -Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -*/ -#ifndef ROC_JPEG_SAMPLES_COMMON -#define ROC_JPEG_SAMPLES_COMMON -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if __cplusplus >= 201703L && __has_include() - #include - namespace fs = std::filesystem; -#else - #include - namespace fs = std::experimental::filesystem; -#endif -#include -#include "rocjpeg/rocjpeg.h" - -#define CHECK_ROCJPEG(call) { \ - RocJpegStatus rocjpeg_status = (call); \ - if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ - std::cerr << #call << " returned " << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ - exit(1); \ - } \ -} - -#define CHECK_HIP(call) { \ - hipError_t hip_status = (call); \ - if (hip_status != hipSuccess) { \ - std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) << "' at " << __FILE__ << ":" << __LINE__ << std::endl;\ - exit(1); \ - } \ -} - -/** - * @class RocJpegUtils - * @brief Utility class for rocJPEG samples. - * - * This class provides utility functions for rocJPEG samples, such as parsing command line arguments, - * getting file paths, initializing HIP device, getting chroma subsampling string, getting channel pitch and sizes, - * getting output file extension, and saving images. - */ -class RocJpegUtils { -public: - /** - * @brief Parses the command line arguments. - * - * This function parses the command line arguments and sets the corresponding variables. - * - * @param input_path The input path. - * @param output_file_path The output file path. - * @param save_images Flag indicating whether to save images. - * @param device_id The device ID. - * @param rocjpeg_backend The rocJPEG backend. - * @param decode_params The rocJPEG decode parameters. - * @param num_threads The number of threads. - * @param crop The crop rectangle. - * @param argc The number of command line arguments. - * @param argv The command line arguments. - */ - static void ParseCommandLine(std::string &input_path, std::string &output_file_path, bool &save_images, int &device_id, - RocJpegBackend &rocjpeg_backend, RocJpegDecodeParams &decode_params, int *num_threads, int *batch_size, int argc, char *argv[]) { - if(argc <= 1) { - ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); - } - for (int i = 1; i < argc; i++) { - if (!strcmp(argv[i], "-h")) { - ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); - } - if (!strcmp(argv[i], "-i")) { - if (++i == argc) { - ShowHelpAndExit("-i", num_threads != nullptr, batch_size != nullptr); - } - input_path = argv[i]; - continue; - } - if (!strcmp(argv[i], "-o")) { - if (++i == argc) { - ShowHelpAndExit("-o", num_threads != nullptr, batch_size != nullptr); - } - output_file_path = argv[i]; - save_images = true; - continue; - } - if (!strcmp(argv[i], "-d")) { - if (++i == argc) { - ShowHelpAndExit("-d", num_threads != nullptr, batch_size != nullptr); - } - device_id = atoi(argv[i]); - continue; - } - if (!strcmp(argv[i], "-be")) { - if (++i == argc) { - ShowHelpAndExit("-be", num_threads != nullptr, batch_size != nullptr); - } - rocjpeg_backend = static_cast(atoi(argv[i])); - continue; - } - if (!strcmp(argv[i], "-fmt")) { - if (++i == argc) { - ShowHelpAndExit("-fmt", num_threads != nullptr, batch_size != nullptr); - } - std::string selected_output_format = argv[i]; - if (selected_output_format == "native") { - decode_params.output_format = ROCJPEG_OUTPUT_NATIVE; - } else if (selected_output_format == "yuv_planar") { - decode_params.output_format = ROCJPEG_OUTPUT_YUV_PLANAR; - } else if (selected_output_format == "y") { - decode_params.output_format = ROCJPEG_OUTPUT_Y; - } else if (selected_output_format == "rgb") { - decode_params.output_format = ROCJPEG_OUTPUT_RGB; - } else if (selected_output_format == "rgb_planar") { - decode_params.output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - } else { - ShowHelpAndExit(argv[i], num_threads != nullptr); - } - continue; - } - if (!strcmp(argv[i], "-t")) { - if (++i == argc) { - ShowHelpAndExit("-t", num_threads != nullptr, batch_size != nullptr); - } - if (num_threads != nullptr) { - *num_threads = atoi(argv[i]); - if (*num_threads <= 0 || *num_threads > 32) { - ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); - } - } - continue; - } - if (!strcmp(argv[i], "-b")) { - if (++i == argc) { - ShowHelpAndExit("-b", num_threads != nullptr, batch_size != nullptr); - } - if (batch_size != nullptr) - *batch_size = atoi(argv[i]); - continue; - } - if (!strcmp(argv[i], "-crop")) { - if (++i == argc || 4 != sscanf(argv[i], "%hd,%hd,%hd,%hd", &decode_params.crop_rectangle.left, &decode_params.crop_rectangle.top, &decode_params.crop_rectangle.right, &decode_params.crop_rectangle.bottom)) { - ShowHelpAndExit("-crop"); - } - if ((&decode_params.crop_rectangle.right - &decode_params.crop_rectangle.left) % 2 == 1 || (&decode_params.crop_rectangle.bottom - &decode_params.crop_rectangle.top) % 2 == 1) { - std::cout << "output crop rectangle must have width and height of even numbers" << std::endl; - exit(1); - } - continue; - } - ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); - } - } - - /** - * Checks if a file is a JPEG file. - * - * @param filePath The path to the file to be checked. - * @return True if the file is a JPEG file, false otherwise. - */ - static bool IsJPEG(const std::string& filePath) { - std::ifstream file(filePath, std::ios::binary); - if (!file.is_open()) { - std::cerr << "Failed to open file: " << filePath << std::endl; - return false; - } - - unsigned char buffer[2]; - file.read(reinterpret_cast(buffer), 2); - file.close(); - - // The first two bytes of every JPEG stream are always 0xFFD8, which represents the Start of Image (SOI) marker. - return buffer[0] == 0xFF && buffer[1] == 0xD8; - } - - /** - * @brief Gets the file paths. - * - * This function gets the file paths based on the input path and sets the corresponding variables. - * - * @param input_path The input path. - * @param file_paths The vector to store the file paths. - * @param is_dir Flag indicating whether the input path is a directory. - * @param is_file Flag indicating whether the input path is a file. - * @return True if successful, false otherwise. - */ - static bool GetFilePaths(std::string &input_path, std::vector &file_paths, bool &is_dir, bool &is_file) { - std::cout << "Reading images from disk, please wait!" << std::endl; - if (!fs::exists(input_path)) { - std::cerr << "ERROR: the input path does not exist!" << std::endl; - return false; - } - is_dir = fs::is_directory(input_path); - is_file = fs::is_regular_file(input_path); - if (is_dir) { - for (const auto &entry : fs::recursive_directory_iterator(input_path)) { - if (fs::is_regular_file(entry) && IsJPEG(entry.path().string())) { - file_paths.push_back(entry.path().string()); - } - } - } else if (is_file && IsJPEG(input_path)) { - file_paths.push_back(input_path); - } else { - std::cerr << "ERROR: the input path does not contain JPEG files!" << std::endl; - return false; - } - return true; - } - - /** - * @brief Initializes the HIP device. - * - * This function initializes the HIP device with the specified device ID. - * - * @param device_id The device ID. - * @return True if successful, false otherwise. - */ - static bool InitHipDevice(int device_id) { - int num_devices; - hipDeviceProp_t hip_dev_prop; - CHECK_HIP(hipGetDeviceCount(&num_devices)); - if (num_devices < 1) { - std::cerr << "ERROR: didn't find any GPU!" << std::endl; - return false; - } - if (device_id >= num_devices) { - std::cerr << "ERROR: the requested device_id is not found!" << std::endl; - return false; - } - CHECK_HIP(hipSetDevice(device_id)); - CHECK_HIP(hipGetDeviceProperties(&hip_dev_prop, device_id)); - - std::cout << "Using GPU device " << device_id << ": " << hip_dev_prop.name << "[" << hip_dev_prop.gcnArchName << "] on PCI bus " << - std::setfill('0') << std::setw(2) << std::right << std::hex << hip_dev_prop.pciBusID << ":" << std::setfill('0') << std::setw(2) << - std::right << std::hex << hip_dev_prop.pciDomainID << "." << hip_dev_prop.pciDeviceID << std::dec << std::endl; - - return true; - } - - /** - * @brief Gets the chroma subsampling string. - * - * This function gets the chroma subsampling string based on the specified subsampling value. - * - * @param subsampling The chroma subsampling value. - * @param chroma_sub_sampling The string to store the chroma subsampling. - */ - void GetChromaSubsamplingStr(RocJpegChromaSubsampling subsampling, std::string &chroma_sub_sampling) { - switch (subsampling) { - case ROCJPEG_CSS_444: - chroma_sub_sampling = "YUV 4:4:4"; - break; - case ROCJPEG_CSS_440: - chroma_sub_sampling = "YUV 4:4:0"; - break; - case ROCJPEG_CSS_422: - chroma_sub_sampling = "YUV 4:2:2"; - break; - case ROCJPEG_CSS_420: - chroma_sub_sampling = "YUV 4:2:0"; - break; - case ROCJPEG_CSS_411: - chroma_sub_sampling = "YUV 4:1:1"; - break; - case ROCJPEG_CSS_400: - chroma_sub_sampling = "YUV 4:0:0"; - break; - case ROCJPEG_CSS_UNKNOWN: - chroma_sub_sampling = "UNKNOWN"; - break; - default: - chroma_sub_sampling = ""; - break; - } - } - - /** - * @brief Gets the channel pitch and sizes. - * - * This function gets the channel pitch and sizes based on the specified output format, chroma subsampling, - * output image, and channel sizes. - * - * @param decode_params The decode parameters that specify the output format and crop rectangle. - * @param subsampling The chroma subsampling. - * @param widths The array to store the channel widths. - * @param heights The array to store the channel heights. - * @param num_channels The number of channels. - * @param output_image The output image. - * @param channel_sizes The array to store the channel sizes. - * @return The channel pitch. - */ - int GetChannelPitchAndSizes(RocJpegDecodeParams decode_params, RocJpegChromaSubsampling subsampling, uint32_t *widths, uint32_t *heights, - uint32_t &num_channels, RocJpegImage &output_image, uint32_t *channel_sizes) { - - bool is_roi_valid = false; - uint32_t roi_width; - uint32_t roi_height; - roi_width = decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; - roi_height = decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; - if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && roi_height <= heights[0]) { - is_roi_valid = true; - } - switch (decode_params.output_format) { - case ROCJPEG_OUTPUT_NATIVE: - switch (subsampling) { - case ROCJPEG_CSS_444: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_440: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_422: - num_channels = 1; - output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 2; - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_420: - num_channels = 2; - output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_400: - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - default: - std::cout << "Unknown chroma subsampling!" << std::endl; - return EXIT_FAILURE; - } - break; - case ROCJPEG_OUTPUT_YUV_PLANAR: - if (subsampling == ROCJPEG_CSS_400) { - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - } else { - num_channels = 3; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[1], mem_alignment); - output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[2], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[1], mem_alignment)); - channel_sizes[2] = output_image.pitch[2] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[2], mem_alignment)); - } - break; - case ROCJPEG_OUTPUT_Y: - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_OUTPUT_RGB: - num_channels = 1; - output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 3; - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_OUTPUT_RGB_PLANAR: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - default: - std::cout << "Unknown output format!" << std::endl; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; - } - - /** - * @brief Gets the output file extension. - * - * This function gets the output file extension based on the specified output format, base file name, - * image width, image height, and file name for saving. - * - * @param output_format The output format. - * @param base_file_name The base file name. - * @param image_width The image width. - * @param image_height The image height. - * @param file_name_for_saving The string to store the file name for saving. - */ - void GetOutputFileExt(RocJpegOutputFormat output_format, std::string &base_file_name, uint32_t image_width, uint32_t image_height, RocJpegChromaSubsampling subsampling, std::string &file_name_for_saving) { - std::string file_extension; - std::string::size_type const p(base_file_name.find_last_of('.')); - std::string file_name_no_ext = base_file_name.substr(0, p); - std::string format_description = ""; - switch (output_format) { - case ROCJPEG_OUTPUT_NATIVE: - file_extension = "yuv"; - switch (subsampling) { - case ROCJPEG_CSS_444: - format_description = "444"; - break; - case ROCJPEG_CSS_440: - format_description = "440"; - break; - case ROCJPEG_CSS_422: - format_description = "422_yuyv"; - break; - case ROCJPEG_CSS_420: - format_description = "nv12"; - break; - case ROCJPEG_CSS_400: - format_description = "400"; - break; - default: - std::cout << "Unknown chroma subsampling!" << std::endl; - return; - } - break; - case ROCJPEG_OUTPUT_YUV_PLANAR: - file_extension = "yuv"; - format_description = "planar"; - break; - case ROCJPEG_OUTPUT_Y: - file_extension = "yuv"; - format_description = "400"; - break; - case ROCJPEG_OUTPUT_RGB: - file_extension = "rgb"; - format_description = "packed"; - break; - case ROCJPEG_OUTPUT_RGB_PLANAR: - file_extension = "rgb"; - format_description = "planar"; - break; - default: - file_extension = ""; - break; - } - file_name_for_saving += "//" + file_name_no_ext + "_" + std::to_string(image_width) + "x" - + std::to_string(image_height) + "_" + format_description + "." + file_extension; - } - -private: - static const int mem_alignment = 16; - /** - * @brief Shows the help message and exits. - * - * This function shows the help message and exits the program. - * - * @param option The option to display in the help message (optional). - * @param show_threads Flag indicating whether to show the number of threads in the help message. - */ - static void ShowHelpAndExit(const char *option = nullptr, bool show_threads = false, bool show_batch_size = false) { - std::cout << "Options:\n" - "-i [input path] - input path to a single JPEG image or a directory containing JPEG images - [required]\n" - "-be [backend] - select rocJPEG backend (0 for hardware-accelerated JPEG decoding using VCN,\n" - " 1 for hybrid JPEG decoding using CPU and GPU HIP kernels (currently not supported)) [optional - default: 0]\n" - "-fmt [output format] - select rocJPEG output format for decoding, one of the [native, yuv_planar, y, rgb, rgb_planar] - [optional - default: native]\n" - "-o [output path] - path to an output file or a path to an existing directory - write decoded images to a file or an existing directory based on selected output format - [optional]\n" - "-crop [crop rectangle] - crop rectangle for output in a comma-separated format: left,top,right,bottom - [optional]\n" - "-d [device id] - specify the GPU device id for the desired device (use 0 for the first device, 1 for the second device, and so on) [optional - default: 0]\n"; - if (show_threads) { - std::cout << "-t [threads] - number of threads (<= 32) for parallel JPEG decoding - [optional - default: 1]\n"; - } - if (show_batch_size) { - std::cout << "-b [batch_size] - decode images from input by batches of a specified size - [optional - default: 1]\n"; - } - exit(0); - } - /** - * @brief Aligns a value to a specified alignment. - * - * This function takes a value and aligns it to the specified alignment. It returns the aligned value. - * - * @param value The value to be aligned. - * @param alignment The alignment value. - * @return The aligned value. - */ - static inline int align(int value, int alignment) { - return (value + alignment - 1) & ~(alignment - 1); - } -}; - -class ThreadPool { - public: - ThreadPool(int nthreads) : shutdown_(false) { - // Create the specified number of threads - threads_.reserve(nthreads); - for (int i = 0; i < nthreads; ++i) - threads_.emplace_back(std::bind(&ThreadPool::ThreadEntry, this, i)); - } - - ~ThreadPool() {} - - void JoinThreads() { - { - // Unblock any threads and tell them to stop - std::unique_lock lock(mutex_); - shutdown_ = true; - cond_var_.notify_all(); - } - - // Wait for all threads to stop - for (auto& thread : threads_) - thread.join(); - } - - void ExecuteJob(std::function func) { - // Place a job on the queue and unblock a thread - std::unique_lock lock(mutex_); - decode_jobs_queue_.emplace(std::move(func)); - cond_var_.notify_one(); - } - - protected: - void ThreadEntry(int i) { - std::function execute_decode_job; - - while (true) { - { - std::unique_lock lock(mutex_); - cond_var_.wait(lock, [&] {return shutdown_ || !decode_jobs_queue_.empty();}); - if (decode_jobs_queue_.empty()) { - // No jobs to do; shutting down - return; - } - - execute_decode_job = std::move(decode_jobs_queue_.front()); - decode_jobs_queue_.pop(); - } - - // Execute the decode job without holding any locks - execute_decode_job(); - } - } - - std::mutex mutex_; - std::condition_variable cond_var_; - bool shutdown_; - std::queue> decode_jobs_queue_; - std::vector threads_; -}; - -#endif //ROC_JPEG_SAMPLES_COMMON From 15d8f1113958795c80acb3ff05cd1958cc3cde32 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 10:09:43 +0000 Subject: [PATCH 6/6] refine code 2 --- setup.py | 6 ++--- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 27 ++----------------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 4b9559eb630..24a43c01778 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" -USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "0") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other @@ -355,12 +355,12 @@ def make_image_extension(): if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() - if nvjpeg_found: + if nvjpeg_found and USE_NVJPEG: print("Building torchvision with NVJPEG image support") libraries.append("nvjpeg") define_macros += [("NVJPEG_FOUND", 1)] Extension = CUDAExtension - elif rocjpeg_found: + elif rocjpeg_found and USE_ROCJPEG: print("Building torchvision with ROCJPEG image support") libraries.append("rocjpeg") define_macros += [("ROCJPEG_FOUND", 1)] diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 9b974b7dc03..e5b49e1067e 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -920,8 +920,7 @@ std::vector RocJpegDecoder::decode_images( Args: - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded - - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB, ROCJPEG_OUTPUT_Y - or ROCJPEG_OUTPUT_NATIVE + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB - device (torch::Device): The desired CUDA device for the returned Tensors Returns: @@ -1017,27 +1016,16 @@ std::vector RocJpegDecoder::decode_images( : align(temp_widths[0], mem_alignment); uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) : align(temp_heights[0], mem_alignment); - auto output_tensor = torch::zeros( + auto output_tensor = torch::empty( {int64_t(num_channels), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(target_device)); channels[j] = num_channels; - // for (int n = 0; n < (int)num_channels; n++) { - // output_images[current_batch_size].channel[n] = - // output_tensor[n].data_ptr(); - // } - // allocate memory for each channel and reuse them if the sizes remain // unchanged for a new image. for (int c = 0; c < (int)num_channels; c++) { output_images[index].channel[c] = output_tensor[c].data_ptr(); } - // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { - // output_images[index].channel[c] = NULL; - // output_images[index].pitch[c] = 0; - // } - // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, - // temp_heights[0]).narrow(2, 0, temp_widths[0]); current_batch_size++; output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) .narrow(2, 0, temp_widths[0]); @@ -1062,17 +1050,6 @@ std::vector RocJpegDecoder::decode_images( "Failed to synchronize CUDA stream: ", cudaStatus); - // prune extraneous channels from single channel images - if (output_format == ROCJPEG_OUTPUT_NATIVE) { - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { - if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); - } - } - } - - cudaDeviceSynchronize(); return output_tensors; }