diff --git a/setup.py b/setup.py index 60d297c35f3..1211db0fdf4 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "") @@ -44,6 +45,7 @@ print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") +print(f"{USE_ROCJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{TORCHVISION_INCLUDE = }") print(f"{TORCHVISION_LIBRARY = }") @@ -340,18 +342,23 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - - if nvjpeg_found: + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() + if nvjpeg_found and USE_NVJPEG: print("Building torchvision with NVJPEG image support") libraries.append("nvjpeg") define_macros += [("NVJPEG_FOUND", 1)] Extension = CUDAExtension + elif rocjpeg_found and USE_ROCJPEG: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension else: - warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") + elif (USE_NVJPEG or USE_ROCJPEG): + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") return Extension( name="torchvision.image", diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 85aa6c760c1..e5b49e1067e 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -1,5 +1,5 @@ #include "decode_jpegs_cuda.h" -#if !NVJPEG_FOUND +#if !NVJPEG_FOUND && !ROCJPEG_FOUND namespace vision { namespace image { std::vector decode_jpegs_cuda( @@ -11,8 +11,9 @@ std::vector decode_jpegs_cuda( } } // namespace image } // namespace vision +#endif -#else +#if NVJPEG_FOUND #include #include #include @@ -600,3 +601,459 @@ std::vector CUDAJpegDecoder::decode_images( } // namespace vision #endif + +#if ROCJPEG_FOUND + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vision { +namespace image { + +std::mutex decoderMutex; +std::unique_ptr rocJpegDecoder; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + + std::lock_guard lock(decoderMutex); + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + + TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + for (auto& encoded_image : encoded_images) { + TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !encoded_image.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + // nvjpeg requires images to be contiguous + if (encoded_image.is_contiguous()) { + contig_images.push_back(encoded_image); + } else { + contig_images.push_back(encoded_image.contiguous()); + } + } + + at::cuda::CUDAGuard device_guard(device); + + if (rocJpegDecoder == nullptr || device != rocJpegDecoder->target_device) { + if (rocJpegDecoder != nullptr) { + rocJpegDecoder.reset(new RocJpegDecoder(device)); + } else { + rocJpegDecoder = std::make_unique(device); + std::atexit([]() { rocJpegDecoder.reset(); }); + } + } + + RocJpegOutputFormat output_format; + + switch (mode) { + case vision::image::IMAGE_READ_MODE_RGB: + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + break; + default: + TORCH_CHECK( + false, + "The provided mode is not supported for ROCJPEG decoding on GPU"); + } + + try { + at::cuda::CUDAEvent event; + auto result = rocJpegDecoder->decode_images(contig_images, output_format); + auto current_stream{ + device.has_index() ? at::cuda::getCurrentCUDAStream( + rocJpegDecoder->original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(rocJpegDecoder->stream); + event.block(current_stream); + return result; + } catch (const std::exception& e) { + if (typeid(e) != typeid(std::runtime_error)) { + TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } else { + throw; + } + } +} + +RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) + : original_device{torch::kCUDA, c10::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + int device_id = target_device.index(); + CHECK_HIP(hipSetDevice(device_id)); + RocJpegStatus status; + RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; + + status = rocJpegCreate(rocjpeg_backend, device_id, &rocjpeg_handle); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, + "Failed to initialize rocjpeg with hardware backend"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[0]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[1]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle); + rocJpegStreamDestroy(rocjpeg_stream_handles[0]); + rocJpegStreamDestroy(rocjpeg_stream_handles[1]); +} + +static constexpr int mem_alignment = 16; + +static inline int align(int value, int alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +void getChromaSubsamplingStr( + RocJpegChromaSubsampling subsampling, + std::string& chroma_sub_sampling) { + switch (subsampling) { + case ROCJPEG_CSS_444: + chroma_sub_sampling = "YUV 4:4:4"; + break; + case ROCJPEG_CSS_440: + chroma_sub_sampling = "YUV 4:4:0"; + break; + case ROCJPEG_CSS_422: + chroma_sub_sampling = "YUV 4:2:2"; + break; + case ROCJPEG_CSS_420: + chroma_sub_sampling = "YUV 4:2:0"; + break; + case ROCJPEG_CSS_411: + chroma_sub_sampling = "YUV 4:1:1"; + break; + case ROCJPEG_CSS_400: + chroma_sub_sampling = "YUV 4:0:0"; + break; + case ROCJPEG_CSS_UNKNOWN: + chroma_sub_sampling = "UNKNOWN"; + break; + default: + chroma_sub_sampling = ""; + break; + } +} + +int getChannelPitchAndSizes( + RocJpegDecodeParams decode_params, + RocJpegChromaSubsampling subsampling, + uint32_t* widths, + uint32_t* heights, + uint32_t& num_channels, + RocJpegImage& output_image, + uint32_t* channel_sizes) { + bool is_roi_valid = false; + uint32_t roi_width; + uint32_t roi_height; + roi_width = + decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; + roi_height = + decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; + if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && + roi_height <= heights[0]) { + is_roi_valid = true; + } + switch (decode_params.output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_440: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_422: + num_channels = 1; + output_image.pitch[0] = + (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 2; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_420: + num_channels = 2; + output_image.pitch[1] = output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return EXIT_FAILURE; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + if (subsampling == ROCJPEG_CSS_400) { + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + } else { + num_channels = 3; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[1], mem_alignment); + output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[2], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[1], mem_alignment)); + channel_sizes[2] = output_image.pitch[2] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[2], mem_alignment)); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 3; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown output format!" << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format) { + /* + This function decodes a batch of jpeg bitstreams. + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB + - device (torch::Device): The desired CUDA device for the returned Tensors + + Returns: + - output_tensors (std::vector): a vector of Tensors + containing the decoded images + */ + + int num_images = encoded_images.size(); + std::vector output_tensors{num_images}; + RocJpegStatus rocjpeg_status; + cudaError_t cudaStatus; + + // baseline JPEGs can be batch decoded with hardware support + std::vector channels(num_images); + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + constexpr int batch_size = 2; + std::string chroma_sub_sampling = ""; + uint8_t num_components; + RocJpegChromaSubsampling temp_subsampling; + std::vector temp_widths(ROCJPEG_MAX_COMPONENT, 0); + std::vector temp_heights(ROCJPEG_MAX_COMPONENT, 0); + RocJpegDecodeParams decode_params = {}; + decode_params.output_format = output_format; + std::vector decode_params_batch; + decode_params_batch.resize(batch_size, decode_params); + std::vector output_images; + output_images.resize(batch_size); + int current_batch_size = 0; + uint32_t channel_sizes[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t num_channels = 0; + std::vector> prior_channel_sizes; + prior_channel_sizes.resize( + batch_size, std::vector(ROCJPEG_MAX_COMPONENT, 0)); + + for (int i = 0; i < num_images; i += batch_size) { + int batch_end = std::min(i + batch_size, num_images); + for (int j = i; j < batch_end; j++) { + int index = j - i; + rocjpeg_status = rocJpegStreamParse( + (unsigned char*)encoded_images[j].data_ptr(), + encoded_images[j].numel(), + rocjpeg_stream_handles[index]); + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { + TORCH_CHECK( + false, + "ERROR: Failed to parse the input jpeg stream with ", + rocJpegGetErrorName(rocjpeg_status)); + } + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle, + rocjpeg_stream_handles[index], + &num_components, + &temp_subsampling, + temp_widths.data(), + temp_heights.data())); + getChromaSubsamplingStr(temp_subsampling, chroma_sub_sampling); + if (temp_widths[0] < 64 || temp_heights[0] < 64) { + TORCH_CHECK( + false, "The image resolution is not supported by VCN Hardware"); + } + if (temp_subsampling == ROCJPEG_CSS_411 || + temp_subsampling == ROCJPEG_CSS_UNKNOWN) { + TORCH_CHECK( + false, "The chroma sub-sampling is not supported by VCN Hardware"); + } + if (getChannelPitchAndSizes( + decode_params_batch[index], + temp_subsampling, + temp_widths.data(), + temp_heights.data(), + num_channels, + output_images[index], + channel_sizes)) { + TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); + } + + uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - + decode_params_batch[index].crop_rectangle.left; + uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - + decode_params_batch[index].crop_rectangle.top; + bool is_roi_valid = + (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && + roi_height <= temp_heights[0]) + ? true + : false; + uint32_t width = is_roi_valid ? align(roi_width, mem_alignment) + : align(temp_widths[0], mem_alignment); + uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) + : align(temp_heights[0], mem_alignment); + auto output_tensor = torch::empty( + {int64_t(num_channels), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(target_device)); + channels[j] = num_channels; + + // allocate memory for each channel and reuse them if the sizes remain + // unchanged for a new image. + for (int c = 0; c < (int)num_channels; c++) { + output_images[index].channel[c] = output_tensor[c].data_ptr(); + } + current_batch_size++; + output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) + .narrow(2, 0, temp_widths[0]); + } + + // if (current_batch_size == 2) { + if (current_batch_size > 0) { + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle, + rocjpeg_stream_handles, + current_batch_size, + decode_params_batch.data(), + output_images.data())); + } + + current_batch_size = 0; + } + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + return output_tensors; +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..5c0fa56113b 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -4,6 +4,7 @@ #include "../common.h" #if NVJPEG_FOUND + #include #include @@ -42,4 +43,55 @@ class CUDAJpegDecoder { }; } // namespace image } // namespace vision + +#endif + +#if ROCJPEG_FOUND + +#include +#include + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + private: + RocJpegStreamHandle rocjpeg_stream_handles[2]; + RocJpegHandle rocjpeg_handle; +}; +} // namespace image +} // namespace vision + +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ + std::cerr << #call << " returned " \ + << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ \ + << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + if (hip_status != hipSuccess) { \ + std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) \ + << "' at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + #endif