Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,26 @@ def make_image_extension():

libraries = []
define_macros, extra_compile_args = get_macros_and_flags()
# PyTorch Stable ABI target version (2.11) - required for string handling in TORCH_BOX
define_macros += [("TORCH_TARGET_VERSION", "0x020b000000000000")]

image_dir = CSRS_DIR / "io/image"
sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c"))
# Exclude *_hip.cpp files - those are hipified versions that would cause multiple definition errors
sources = [s for s in image_dir.glob("*.cpp") if not s.name.endswith("_hip.cpp")]
sources += [s for s in image_dir.glob("cpu/*.cpp") if not s.name.endswith("_hip.cpp")]
sources += list(image_dir.glob("cpu/giflib/*.c"))

# Always include CUDA sources - they have stubs when NVJPEG_FOUND is not defined
if IS_ROCM:
sources += list(image_dir.glob("hip/*.cpp"))
# we need to exclude this in favor of the hipified source
sources.remove(image_dir / "image.cpp")
hip_sources = list(image_dir.glob("hip/*.cpp"))
if hip_sources:
sources += hip_sources
# Only remove image.cpp if we have a hipified replacement
if (image_dir / "image.cpp") in sources:
sources.remove(image_dir / "image.cpp")
else:
# No hip/ directory - use cuda sources (they have stubs for non-NVJPEG builds)
sources += list(image_dir.glob("cuda/*.cpp"))
else:
sources += list(image_dir.glob("cuda/*.cpp"))

Expand Down
19 changes: 7 additions & 12 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,14 @@ def test_encode_png_errors():
"img_path",
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
def test_write_png(img_path, tmpdir):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)

filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
write = torch.jit.script(write_png) if scripted else write_png
write(img_pil, torch_png, compression_level=6)
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1)

Expand Down Expand Up @@ -325,13 +323,11 @@ def test_read_file_non_ascii(tmpdir):
assert_equal(data, expected)


@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
def test_write_file(tmpdir):
fname, content = "test1.bin", b"TorchVision\211\n"
fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write = torch.jit.script(write_file) if scripted else write_file
write(fpath, content_tensor)
write_file(fpath, content_tensor)

with open(fpath, "rb") as f:
saved_content = f.read()
Expand Down Expand Up @@ -808,17 +804,15 @@ def test_batch_encode_jpegs_cuda_errors():
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
def test_write_jpeg(img_path, tmpdir):
tmpdir = Path(tmpdir)
img = read_image(img_path)
pil_img = F.to_pil_image(img)

torch_jpeg = str(tmpdir / "torch.jpg")
pil_jpeg = str(tmpdir / "pil.jpg")

write = torch.jit.script(write_jpeg) if scripted else write_jpeg
write(img, torch_jpeg, quality=75)
write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)

with open(torch_jpeg, "rb") as f:
Expand Down Expand Up @@ -935,6 +929,7 @@ def test_decode_webp(decode_fun, scripted):
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.skip(reason="TODO_STABLE_ABI: need TORCH_WARN_ONCE")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
def test_decode_webp_grayscale(decode_fun, capfd):
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp")))
Expand Down
Loading
Loading