diff --git a/src/Native/LibTorchSharp/THSVision.cpp b/src/Native/LibTorchSharp/THSVision.cpp index 5cc6f832d..5fd3ecdcf 100644 --- a/src/Native/LibTorchSharp/THSVision.cpp +++ b/src/Native/LibTorchSharp/THSVision.cpp @@ -184,7 +184,7 @@ Tensor THSVision_ApplyGridTransform(Tensor i, Tensor g, const int8_t m, const fl if (m == 0) { mask = mask < 0.5; - img[mask] = fill_img[mask]; + img = torch::where(mask, fill_img, img); } else { img = img * mask + (-mask + 1.0) * fill_img; diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index c8f1bc341..69ad3cf72 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -1143,6 +1143,60 @@ public void TestRotateImage45DegreesClockwise() } } + [Fact] + public void TestAffineTransform3D() + { + // 3D input: [C, H, W] + var img = torch.rand(new long[] { 1, 48, 48 }); + var result = functional.affine( + img, + angle: 0f, + translate: new[] { 0, 0 }, + scale: 1f, + shear: new[] { 1f, 1f }, + fill: 0); + Assert.Equal(img.shape, result.shape); + } + + [Fact] + public void TestAffineTransform4D() + { + // 4D input: [N, C, H, W] — reproduces issue #1502 + var img = torch.rand(new long[] { 1, 1, 48, 48 }); + var result = functional.affine( + img, + angle: 0f, + translate: new[] { 0, 0 }, + scale: 1f, + shear: new[] { 1f, 1f }, + fill: 0); + Assert.Equal(img.shape, result.shape); + } + + [Fact] + public void TestAffineTransform4DBatch() + { + // 4D input with batch > 1 + var img = torch.rand(new long[] { 4, 3, 32, 32 }); + var result = functional.affine( + img, + angle: 15f, + translate: new[] { 5, 5 }, + scale: 0.9f, + shear: new[] { 10f, 5f }, + fill: 0); + Assert.Equal(img.shape, result.shape); + } + + [Fact] + public void TestRotateWithFill() + { + // Rotate with fill also uses ApplyGridTransform — verify it works + var img = torch.rand(new long[] { 1, 1, 48, 48 }); + var result = functional.rotate(img, 45f, InterpolationMode.Nearest, fill: new float[] { 0f }); + Assert.Equal(img.shape, result.shape); + } + [Fact] public void Solarize_InvertedPixel_True() {