From 6776d0aa1d6555ca8fcce2c438a16f016fc60456 Mon Sep 17 00:00:00 2001
From: Josh Williams
Date: Tue, 1 Jul 2025 10:21:10 +0100
Subject: [PATCH 1/4] Add transpose convolution support to distconv

---
 distconv/distconv.py | 93 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 76 insertions(+), 17 deletions(-)

diff --git a/distconv/distconv.py b/distconv/distconv.py
index 7316163..178195b 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -1,3 +1,4 @@
+from copy import copy
 from typing import Callable, Dict, List, Tuple
 
 import torch
@@ -77,7 +78,7 @@ def check_is_distconv_supported(
     if kernel_size % 2 == 1:
         if (kernel_size // 2) != padding[shard_dim]:
             raise Exception(
-                'DistConv: when kernel size is odd, padding must be equivalent to "same"'
+                f'DistConv: when kernel size is odd ({kernel_size}), padding must be equivalent to "same" but found ({padding})'
             )
     else:
         if padding[shard_dim] != 0:
@@ -240,17 +241,30 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     args = list(args)
 
     # Unpack the necessary arguments
-    tensor, weight, bias, stride, padding, dilation = args[:6]
+    tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[
+        :8
+    ]
+    padding_orig = copy(padding)
 
     # Extract the parallel strategy and shard dimension from the input tensor
     parallel_strategy = tensor._parallel_strategy
     shard_dim = parallel_strategy.shard_dim
+    shard_ind = parallel_strategy.shard_ind
+    world_size = parallel_strategy.world_size
+    kernel_size = weight.size(shard_dim)
     is_periodic = tensor._is_periodic
     if is_periodic:
-        assert padding[shard_dim - 2] == 0, (
-            "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
-        )
-        padding[shard_dim - 2] = tensor._periodic_shard_padding
+        if transpose:
+            assert padding[shard_dim - 2] == dilation[shard_dim - 2] * (
+                kernel_size - 1
+            ), f"padding is incorrect"
+            padding[shard_dim - 2] = tensor._periodic_shard_padding
+            padding_orig[shard_dim - 2] = tensor._periodic_shard_padding
+        else:
+            assert padding[shard_dim - 2] == 0, (
+                "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
+            )
+            padding[shard_dim - 2] = tensor._periodic_shard_padding
 
     # Unwrap the underlying tensor from the DCTensor
     torch_tensor = tensor._tensor
@@ -261,7 +275,6 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     )
 
     # Determine the halo size for halo exchange
-    kernel_size = weight.size(shard_dim)
     halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
 
     # Perform forward halo exchange to prepare the tensor for convolution
@@ -276,9 +289,23 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     )
 
     # Update the arguments with the tensor including halos and adjusted padding
-    padding[shard_dim - 2] = 0
+    if transpose:
+        padding[shard_dim - 2] = dilation[shard_dim - 2] * (kernel_size - 1)
+        for dim_i in range(tensor.ndim - 2):
+            if dim_i + 2 == shard_dim:
+                padding[shard_dim - 2] += (kernel_size - 1 - padding_orig[dim_i]) * (
+                    stride[dim_i] - 1
+                )
+                # modify output_padding for strided transpose convolution
+                if shard_ind < world_size - 1:
+                    output_padding[dim_i] += (
+                        tensor.size(shard_dim) + 2 * padding_orig[dim_i] - kernel_size
+                    ) % stride[dim_i]
+    else:
+        padding[shard_dim - 2] = 0
     args[0] = tensor_with_halo
     args[4] = padding
+    args[7] = output_padding
 
     # Perform the convolution operation
     out_tensor = func(*args, **kwargs)
@@ -305,19 +332,38 @@ def distconv_backward(
     args = list(args)
 
     # Unpack the necessary arguments
-
grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[ - :7 - ] + ( + grad_out_tensor, + input_tensor, + weight, + bias_size, + stride, + padding, + dilation, + transpose, + output_padding, + ) = args[:9] + padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the gradient output tensor parallel_strategy = grad_out_tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim + shard_ind = parallel_strategy.shard_ind + world_size = parallel_strategy.world_size is_periodic = input_tensor._is_periodic + kernel_size = weight.size(shard_dim) if is_periodic: - assert padding[shard_dim - 2] == 0, ( - "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" - ) - padding[shard_dim - 2] = input_tensor._periodic_shard_padding + if transpose: + assert padding[shard_dim - 2] == dilation[shard_dim - 2] * ( + kernel_size - 1 + ), f"shard-dim padding incorrect" + padding[shard_dim - 2] = input_tensor._periodic_shard_padding + padding_orig = [input_tensor._periodic_shard_padding] * len(padding) + else: + assert ( + padding[shard_dim - 2] == 0 + ), "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + padding[shard_dim - 2] = input_tensor._periodic_shard_padding # Unwrap the underlying tensors from the DCTensors grad_out_tensor = grad_out_tensor._tensor @@ -329,7 +375,6 @@ def distconv_backward( ) # Determine the halo size for halo exchange - kernel_size = weight.size(shard_dim) halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 # Get the input tensor including halos if available, otherwise perform forward halo exchange @@ -341,10 +386,24 @@ def distconv_backward( ) # Update the arguments with the gradient output tensor, input tensor including halos, and adjusted padding - padding[shard_dim - 2] = 0 + if transpose: + padding[shard_dim - 2] = dilation[shard_dim - 2] * (kernel_size - 1) + for dim_i in range(input_tensor.ndim - 2): + if dim_i + 2 == shard_dim: + padding[dim_i] += padding_orig[dim_i] * (stride[dim_i] - 1) + if shard_ind < world_size - 1: + crop_amount = ( + input_tensor.size(shard_dim) + + 2 * padding_orig[dim_i] + - kernel_size + ) % stride[dim_i] + output_padding[dim_i] = crop_amount + else: + padding[shard_dim - 2] = 0 args[0] = grad_out_tensor args[1] = input_tensor_with_halo args[5] = padding + args[8] = output_padding # Perform the backward convolution operation grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs) From 9507e5d7369d39a359e58f97ea5289f487d24723 Mon Sep 17 00:00:00 2001 From: Josh Williams Date: Tue, 1 Jul 2025 10:21:29 +0100 Subject: [PATCH 2/4] Add transpose convolution unit test for zero and periodic padding --- tests/test_convtranspose.py | 189 ++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/test_convtranspose.py diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py new file mode 100644 index 0000000..29f059d --- /dev/null +++ b/tests/test_convtranspose.py @@ -0,0 +1,189 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from utils import cleanup_parallel_strategy, fp32_allclose + +from distconv import DCTensor, DistConvDDP, ParallelStrategy + + +def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]: + """Gather tensors with the same number of dimensions but different lengths. 
+ + Credit: https://stackoverflow.com/a/78934638 + """ + world_size = dist.get_world_size(group=group) + # Gather lengths first + shape = torch.as_tensor(tensor.shape, device=tensor.device) + shapes = [torch.empty_like(shape) for _ in range(world_size)] + dist.all_gather(shapes, shape, group=group) + # Gather data + inputs = [tensor] * world_size + outputs = [ + torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device) + for _shape in shapes + ] + dist.all_to_all(outputs, inputs, group=group) + return torch.cat(outputs, dim=dim) + + +@pytest.fixture(scope="module") +def parallel_strategy(device: torch.device): + ps = ParallelStrategy(num_shards=4, device_type=device.type) + yield ps + cleanup_parallel_strategy(ps) + + +def find_padding(kernel_size): + if kernel_size % 2 != 0: + return kernel_size // 2 + else: + return 0 + + +def generate_configs(): + configs = [] + for ndims in [1, 2, 3]: + for shard_dim in range(ndims): + for kernel_size in [1, 3, 5]: + for stride in [1, 2, 4]: + configs.append((ndims, shard_dim, kernel_size, stride)) + + return "ndims,shard_dim,kernel_size,stride", configs + + +@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_zerospadding( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + padding: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. + Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding = find_padding(kernel_size) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = conv_class(4, 8, kernel_size=kernel_size, padding=padding, stride=stride).to( + device + ) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x) + ref_y.sum().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcy = dist_conv(dcx) + dcy_merge = all_gather_vlen(dcy, dim=(parallel_strategy.shard_dim)) + dc_loss = dcy.sum() + dist.all_reduce(dc_loss) + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) + + +@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_circularpadding( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. 
+ Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding = find_padding(kernel_size) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + + conv_kwargs = dict(kernel_size=kernel_size, stride=stride) + + # set periodic padding case for reference + new_padding = [padding, padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=new_padding, mode="circular") + ref_padding = kernel_size - 1 + + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = ( + conv_class(4, 8, padding=ref_padding, **conv_kwargs) + .to(device) + .requires_grad_(False) + ) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x_periodic) + for i in range(0, ndims): + crop_amount = (kernel_size - 1 - padding) * (stride - 1) + ref_y = ref_y.narrow(i + 2, crop_amount, ref_y.shape[i + 2] - 2 * crop_amount) + ref_y.sum().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcx_periodic = torch.nn.functional.pad(input=dcx, pad=new_padding, mode="circular") + dcy = dist_conv(dcx_periodic) + for i in range(0, ndims): + if i != shard_dim: + crop_amount = (kernel_size - 1 - padding) * (stride - 1) + dcy = dcy.narrow(i + 2, crop_amount, dcy.shape[i + 2] - 2 * crop_amount) + dcy_merge = all_gather_vlen(dcy.contiguous(), dim=(parallel_strategy.shard_dim)) + dc_loss = dcy.sum() + dist.all_reduce(dc_loss) + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + # Validate the results + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) From 70f1cec4b4cd4bd3e22fb029c8ac132c0810955a Mon Sep 17 00:00:00 2001 From: Pier Fiedorowicz Date: Wed, 9 Jul 2025 16:37:20 -0700 Subject: [PATCH 3/4] Improve and simplify constraint checking and update tests --- distconv/distconv.py | 117 +++++++++++++++++------------------- tests/test_convtranspose.py | 66 ++++++++++---------- tests/test_periodic.py | 3 +- 3 files changed, 92 insertions(+), 94 deletions(-) diff --git a/distconv/distconv.py b/distconv/distconv.py index 178195b..1df47af 100644 --- a/distconv/distconv.py +++ b/distconv/distconv.py @@ -50,6 +50,8 @@ def check_is_distconv_supported( stride: List[int], padding: List[int], dilation: List[int], + transpose: bool, + output_padding: List[int], ) -> None: """ Check if the distributed convolution is supported with the given parameters. @@ -61,31 +63,42 @@ def check_is_distconv_supported( stride (List[int]): The stride of the convolution. padding (List[int]): The padding added to the input tensor. dilation (List[int]): The dilation applied to the kernel. 
+        transpose (bool): Whether the convolution is transposed.
+        output_padding (List[int]): The output padding for transposed convolution.
 
     Raises:
         Exception: If dilation is not 1.
-        Exception: If input size is not divisible by stride.
-        Exception: If kernel size is odd and padding is not equivalent to "same".
-        Exception: If kernel size is even and padding is not zero.
-        Exception: If kernel size is even and stride is not divisible by kernel size.
+        Exception: If local input size is not equal to stride times output size.
+        Exception: If local output size is not equal to stride times input size for transposed convolution.
     """
     shard_dim = tensor_shard_dim - 2
     kernel_size = weight.size(tensor_shard_dim)
 
     if dilation[shard_dim] != 1:
         raise Exception("DistConv: dilation must be 1")
-    if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0:
-        raise Exception("DistConv: input size must be divisible by stride")
-    if kernel_size % 2 == 1:
-        if (kernel_size // 2) != padding[shard_dim]:
+
+    input_size = tensor.size(tensor_shard_dim)
+
+    if not transpose:
+        output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[
+            shard_dim
+        ] + 1
+
+        if output_size * stride[shard_dim] != input_size:
             raise Exception(
-                f'DistConv: when kernel size is odd ({kernel_size}), padding must be equivalent to "same" but found ({padding})'
+                "DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n"
+                + "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy."
             )
     else:
-        if padding[shard_dim] != 0:
-            raise Exception("DistConv: when kernel size is even, padding must be zero")
-        if stride[shard_dim] % kernel_size != 0:
+        output_size = (
+            (input_size - 1) * stride[shard_dim]
+            - 2 * padding[shard_dim]
+            + kernel_size
+            + output_padding[shard_dim]
+        )
+
+        if output_size != input_size * stride[shard_dim]:
             raise Exception(
-                "DistConv: when kernel size is even, stride must be divisble by kernel size"
+                "DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n"
+                + "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy."
) @@ -244,22 +257,16 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[ :8 ] - padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the input tensor parallel_strategy = tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim - shard_ind = parallel_strategy.shard_ind - world_size = parallel_strategy.world_size - kernel_size = weight.size(shard_dim) is_periodic = tensor._is_periodic if is_periodic: if transpose: - assert padding[shard_dim - 2] == dilation[shard_dim - 2] * ( - kernel_size - 1 - ), f"padding is incorrect" - padding[shard_dim - 2] = tensor._periodic_shard_padding - padding_orig[shard_dim - 2] = tensor._periodic_shard_padding + padding[shard_dim - 2] -= ( + stride[shard_dim - 2] * tensor._periodic_shard_padding + ) else: assert padding[shard_dim - 2] == 0, ( "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" @@ -271,11 +278,18 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": # Check if the distributed convolution is supported with the given parameters check_is_distconv_supported( - shard_dim, torch_tensor, weight, stride, padding, dilation + shard_dim, + torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, ) # Determine the halo size for halo exchange - halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 + halo_size = padding[shard_dim - 2] # Perform forward halo exchange to prepare the tensor for convolution tensor_with_halo = forward_halo_exchange( @@ -290,17 +304,7 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": # Update the arguments with the tensor including halos and adjusted padding if transpose: - padding[shard_dim - 2] = dilation[shard_dim - 2] * (kernel_size - 1) - for dim_i in range(tensor.ndim - 2): - if dim_i + 2 == shard_dim: - padding[shard_dim - 2] += (kernel_size - 1 - padding_orig[dim_i]) * ( - stride[dim_i] - 1 - ) - # modify output_padding for strided transpose convolution - if shard_ind < world_size - 1: - output_padding[dim_i] += ( - tensor.size(shard_dim) + 2 * padding_orig[dim_i] - kernel_size - ) % stride[dim_i] + padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size else: padding[shard_dim - 2] = 0 args[0] = tensor_with_halo @@ -343,26 +347,20 @@ def distconv_backward( transpose, output_padding, ) = args[:9] - padding_orig = copy(padding) # Extract the parallel strategy and shard dimension from the gradient output tensor parallel_strategy = grad_out_tensor._parallel_strategy shard_dim = parallel_strategy.shard_dim - shard_ind = parallel_strategy.shard_ind - world_size = parallel_strategy.world_size is_periodic = input_tensor._is_periodic - kernel_size = weight.size(shard_dim) if is_periodic: if transpose: - assert padding[shard_dim - 2] == dilation[shard_dim - 2] * ( - kernel_size - 1 - ), f"shard-dim padding incorrect" - padding[shard_dim - 2] = input_tensor._periodic_shard_padding - padding_orig = [input_tensor._periodic_shard_padding] * len(padding) + padding[shard_dim - 2] -= ( + stride[shard_dim - 2] * input_tensor._periodic_shard_padding + ) else: - assert ( - padding[shard_dim - 2] == 0 - ), "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + assert padding[shard_dim - 2] == 0, ( + "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + ) padding[shard_dim - 2] = input_tensor._periodic_shard_padding 
# Unwrap the underlying tensors from the DCTensors @@ -371,11 +369,18 @@ def distconv_backward( # Check if the distributed convolution is supported with the given parameters check_is_distconv_supported( - shard_dim, input_torch_tensor, weight, stride, padding, dilation + shard_dim, + input_torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, ) # Determine the halo size for halo exchange - halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 + halo_size = padding[shard_dim - 2] # Get the input tensor including halos if available, otherwise perform forward halo exchange if input_tensor._tensor_with_halo is not None: @@ -387,17 +392,7 @@ def distconv_backward( # Update the arguments with the gradient output tensor, input tensor including halos, and adjusted padding if transpose: - padding[shard_dim - 2] = dilation[shard_dim - 2] * (kernel_size - 1) - for dim_i in range(input_tensor.ndim - 2): - if dim_i + 2 == shard_dim: - padding[dim_i] += padding_orig[dim_i] * (stride[dim_i] - 1) - if shard_ind < world_size - 1: - crop_amount = ( - input_tensor.size(shard_dim) - + 2 * padding_orig[dim_i] - - kernel_size - ) % stride[dim_i] - output_padding[dim_i] = crop_amount + padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size else: padding[shard_dim - 2] = 0 args[0] = grad_out_tensor diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py index 29f059d..5c6db17 100644 --- a/tests/test_convtranspose.py +++ b/tests/test_convtranspose.py @@ -34,11 +34,13 @@ def parallel_strategy(device: torch.device): cleanup_parallel_strategy(ps) -def find_padding(kernel_size): - if kernel_size % 2 != 0: - return kernel_size // 2 - else: - return 0 +def find_padding(kernel_size, stride=1, explicit_padding=False): + ep = kernel_size // 2 if explicit_padding else 0 + pad = (kernel_size + 2 * ep * stride - 1) // 2 + out_pad = stride - 1 + if explicit_padding: + return pad, out_pad, ep + return pad, out_pad def generate_configs(): @@ -58,7 +60,6 @@ def test_transposeconv_zerospadding( ndims: int, shard_dim: int, kernel_size: int, - padding: int, stride: int, device: torch.device, ): @@ -77,20 +78,25 @@ def test_transposeconv_zerospadding( """ # Set the shard dimension for the parallel strategy parallel_strategy.shard_dim = 2 + shard_dim - padding = find_padding(kernel_size) + padding, output_padding = find_padding(kernel_size, stride) # Initialize the input tensor and convolution layer shape = [1, 4] + [64] * ndims x = torch.randn(*shape, device=device, requires_grad=True) conv_class = getattr(nn, f"ConvTranspose{ndims}d") - conv = conv_class(4, 8, kernel_size=kernel_size, padding=padding, stride=stride).to( - device - ) + conv = conv_class( + 4, + 8, + kernel_size=kernel_size, + padding=padding, + stride=stride, + output_padding=output_padding, + ).to(device) # Perform forward and backward pass for reference (non-distributed) convolution conv.zero_grad() ref_y = conv(x) - ref_y.sum().backward() + ref_y.square().mean().backward() ref_x_grad = x.grad ref_conv_grad = conv.weight.grad @@ -99,9 +105,8 @@ def test_transposeconv_zerospadding( dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) dcx = DCTensor.distribute(x, parallel_strategy) dcy = dist_conv(dcx) - dcy_merge = all_gather_vlen(dcy, dim=(parallel_strategy.shard_dim)) - dc_loss = dcy.sum() - dist.all_reduce(dc_loss) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() dc_loss.backward() x_grad = dcx.grad.to_replicate() dc_conv_grad = conv.weight.grad @@ -135,22 +140,25 
@@ def test_transposeconv_circularpadding( """ # Set the shard dimension for the parallel strategy parallel_strategy.shard_dim = 2 + shard_dim - padding = find_padding(kernel_size) + padding, output_padding, explicit_padding = find_padding( + kernel_size, stride, explicit_padding=True + ) # Initialize the input tensor and convolution layer shape = [1, 4] + [64] * ndims x = torch.randn(*shape, device=device, requires_grad=True) - conv_kwargs = dict(kernel_size=kernel_size, stride=stride) + conv_kwargs = dict( + kernel_size=kernel_size, stride=stride, output_padding=output_padding + ) # set periodic padding case for reference - new_padding = [padding, padding] * ndims - x_periodic = torch.nn.functional.pad(input=x, pad=new_padding, mode="circular") - ref_padding = kernel_size - 1 + explicit_padding = [explicit_padding, explicit_padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=explicit_padding, mode="circular") conv_class = getattr(nn, f"ConvTranspose{ndims}d") conv = ( - conv_class(4, 8, padding=ref_padding, **conv_kwargs) + conv_class(4, 8, padding=padding, **conv_kwargs) .to(device) .requires_grad_(False) ) @@ -159,10 +167,7 @@ def test_transposeconv_circularpadding( # Perform forward and backward pass for reference (non-distributed) convolution conv.zero_grad() ref_y = conv(x_periodic) - for i in range(0, ndims): - crop_amount = (kernel_size - 1 - padding) * (stride - 1) - ref_y = ref_y.narrow(i + 2, crop_amount, ref_y.shape[i + 2] - 2 * crop_amount) - ref_y.sum().backward() + ref_y.square().mean().backward() ref_x_grad = x.grad ref_conv_grad = conv.weight.grad @@ -170,15 +175,12 @@ def test_transposeconv_circularpadding( conv.zero_grad() dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) dcx = DCTensor.distribute(x, parallel_strategy) - dcx_periodic = torch.nn.functional.pad(input=dcx, pad=new_padding, mode="circular") + dcx_periodic = torch.nn.functional.pad( + input=dcx, pad=explicit_padding, mode="circular" + ) dcy = dist_conv(dcx_periodic) - for i in range(0, ndims): - if i != shard_dim: - crop_amount = (kernel_size - 1 - padding) * (stride - 1) - dcy = dcy.narrow(i + 2, crop_amount, dcy.shape[i + 2] - 2 * crop_amount) - dcy_merge = all_gather_vlen(dcy.contiguous(), dim=(parallel_strategy.shard_dim)) - dc_loss = dcy.sum() - dist.all_reduce(dc_loss) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() dc_loss.backward() x_grad = dcx.grad.to_replicate() dc_conv_grad = conv.weight.grad diff --git a/tests/test_periodic.py b/tests/test_periodic.py index af6e309..13edc6d 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -1,3 +1,4 @@ +from math import ceil import pytest import torch import torch.distributed as dist @@ -50,7 +51,7 @@ def test_periodic( conv_kwargs = dict( kernel_size=kernel_size, - padding=kernel_size // 2, + padding=ceil((kernel_size - stride) / 2), bias=False, stride=stride, padding_mode="circular", From 7095c8656e6d39df89f69f959b849b5e8e6789d6 Mon Sep 17 00:00:00 2001 From: Pier Fiedorowicz Date: Wed, 9 Jul 2025 16:46:00 -0700 Subject: [PATCH 4/4] Ruff formatting --- distconv/distconv.py | 1 - tests/test_basic.py | 4 ++-- tests/test_convtranspose.py | 4 ++-- tests/test_ddp_with_distconv.py | 4 ++-- tests/test_periodic.py | 5 +++-- tests/test_strides.py | 4 ++-- tests/utils.py | 1 - 7 files changed, 11 insertions(+), 12 deletions(-) diff --git a/distconv/distconv.py b/distconv/distconv.py index 1df47af..c89e4da 100644 --- a/distconv/distconv.py +++ b/distconv/distconv.py @@ -1,4 +1,3 @@ -from copy 
import copy from typing import Callable, Dict, List, Tuple import torch diff --git a/tests/test_basic.py b/tests/test_basic.py index d0748f5..9322cc1 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py index 5c6db17..f353955 100644 --- a/tests/test_convtranspose.py +++ b/tests/test_convtranspose.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]: """Gather tensors with the same number of dimensions but different lengths. diff --git a/tests/test_ddp_with_distconv.py b/tests/test_ddp_with_distconv.py index f683d21..189c38e 100644 --- a/tests/test_ddp_with_distconv.py +++ b/tests/test_ddp_with_distconv.py @@ -2,11 +2,11 @@ import torch import torch.distributed as dist import torch.nn as nn +from distconv import DCTensor, DistConvDDP, ParallelStrategy from torch.distributed.tensor import Replicate, Shard, distribute_tensor from torch.nn.parallel import DistributedDataParallel as DDP -from utils import cleanup_parallel_strategy, fp32_allclose -from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose @pytest.fixture(scope="module") diff --git a/tests/test_periodic.py b/tests/test_periodic.py index 13edc6d..9d5dd14 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -1,12 +1,13 @@ from math import ceil + import pytest import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + def generate_configs(): configs = [] diff --git a/tests/test_strides.py b/tests/test_strides.py index efa82a6..d9e421e 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + @pytest.fixture(scope="module") def parallel_strategy(device: torch.device): diff --git a/tests/utils.py b/tests/utils.py index 891254b..4e6fb39 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist - from distconv import ParallelStrategy
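
Note on the constraint introduced in patch 3: check_is_distconv_supported now reduces the shard-dimension requirements to a dilation check plus a single size relation between the local input and output lengths. The standalone sketch below restates that relation as plain shape arithmetic; the helper name shard_dim_sizes_compatible and the example numbers are illustrative assumptions, not part of distconv.

# Illustrative sketch (assumed helper, not part of distconv): the local-shard
# size relation that patch 3's check_is_distconv_supported enforces.
def shard_dim_sizes_compatible(
    input_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int = 1,
    transpose: bool = False,
    output_padding: int = 0,
) -> bool:
    """Return True if the local shard length satisfies DistConv's constraint."""
    if dilation != 1:
        # DistConv only supports dilation == 1 on the shard dimension.
        return False
    if not transpose:
        # Standard convolution: output = (in + 2*pad - kernel) // stride + 1,
        # and DistConv requires output * stride == input on every shard.
        output_size = (input_size + 2 * padding - kernel_size) // stride + 1
        return output_size * stride == input_size
    # Transposed convolution (PyTorch's output-length formula with dilation == 1):
    # output = (in - 1) * stride - 2*pad + kernel + output_padding,
    # and DistConv requires output == input * stride on every shard.
    output_size = (input_size - 1) * stride - 2 * padding + kernel_size + output_padding
    return output_size == input_size * stride


# Example matching tests/test_convtranspose.py: a length-64 axis split over
# 4 shards gives a local length of 16; with kernel_size=3 and stride=2,
# find_padding returns padding=1 and output_padding=1, so the local output
# length is (16 - 1) * 2 - 2 + 3 + 1 = 32 == 16 * 2.
assert shard_dim_sizes_compatible(16, 3, 2, 1, transpose=True, output_padding=1)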