diff --git a/distconv/distconv.py b/distconv/distconv.py
index 7316163..c89e4da 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -49,6 +49,8 @@ def check_is_distconv_supported(
     stride: List[int],
     padding: List[int],
     dilation: List[int],
+    transpose: bool,
+    output_padding: List[int],
 ) -> None:
     """
     Check if the distributed convolution is supported with the given parameters.
@@ -60,31 +62,42 @@
         stride (List[int]): The stride of the convolution.
         padding (List[int]): The padding added to the input tensor.
        dilation (List[int]): The dilation applied to the kernel.
+        transpose (bool): Whether the convolution is transposed.
+        output_padding (List[int]): The output padding for transposed convolution.
 
     Raises:
-        Exception: If dilation is not 1.
-        Exception: If input size is not divisible by stride.
-        Exception: If kernel size is odd and padding is not equivalent to "same".
-        Exception: If kernel size is even and padding is not zero.
-        Exception: If kernel size is even and stride is not divisible by kernel size.
+        Exception: If local input size is not equal to stride times output size.
+        Exception: If local output size is not equal to stride times input size for transposed convolution.
     """
     shard_dim = tensor_shard_dim - 2
     kernel_size = weight.size(tensor_shard_dim)
     if dilation[shard_dim] != 1:
         raise Exception("DistConv: dilation must be 1")
-    if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0:
-        raise Exception("DistConv: input size must be divisible by stride")
-    if kernel_size % 2 == 1:
-        if (kernel_size // 2) != padding[shard_dim]:
+
+    input_size = tensor.size(tensor_shard_dim)
+
+    if not transpose:
+        output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[
+            shard_dim
+        ] + 1
+
+        if output_size * stride[shard_dim] != input_size:
             raise Exception(
-                'DistConv: when kernel size is odd, padding must be equivalent to "same"'
+                "DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n"
+                + "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy."
             )
     else:
-        if padding[shard_dim] != 0:
-            raise Exception("DistConv: when kernel size is even, padding must be zero")
-        if stride[shard_dim] % kernel_size != 0:
+        output_size = (
+            (input_size - 1) * stride[shard_dim]
+            - 2 * padding[shard_dim]
+            + kernel_size
+            + output_padding[shard_dim]
+        )
+
+        if output_size != input_size * stride[shard_dim]:
             raise Exception(
-                "DistConv: when kernel size is even, stride must be divisble by kernel size"
+                "DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n"
+                + "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy."
             )
 
 
@@ -240,29 +253,42 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     args = list(args)
 
     # Unpack the necessary arguments
-    tensor, weight, bias, stride, padding, dilation = args[:6]
+    tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[
+        :8
+    ]
 
     # Extract the parallel strategy and shard dimension from the input tensor
     parallel_strategy = tensor._parallel_strategy
     shard_dim = parallel_strategy.shard_dim
     is_periodic = tensor._is_periodic
     if is_periodic:
-        assert padding[shard_dim - 2] == 0, (
-            "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
-        )
-        padding[shard_dim - 2] = tensor._periodic_shard_padding
+        if transpose:
+            padding[shard_dim - 2] -= (
+                stride[shard_dim - 2] * tensor._periodic_shard_padding
+            )
+        else:
+            assert padding[shard_dim - 2] == 0, (
+                "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
+            )
+            padding[shard_dim - 2] = tensor._periodic_shard_padding
 
     # Unwrap the underlying tensor from the DCTensor
     torch_tensor = tensor._tensor
 
     # Check if the distributed convolution is supported with the given parameters
     check_is_distconv_supported(
-        shard_dim, torch_tensor, weight, stride, padding, dilation
+        shard_dim,
+        torch_tensor,
+        weight,
+        stride,
+        padding,
+        dilation,
+        transpose,
+        output_padding,
     )
 
     # Determine the halo size for halo exchange
-    kernel_size = weight.size(shard_dim)
-    halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
+    halo_size = padding[shard_dim - 2]
 
     # Perform forward halo exchange to prepare the tensor for convolution
     tensor_with_halo = forward_halo_exchange(
@@ -276,9 +302,13 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     )
 
     # Update the arguments with the tensor including halos and adjusted padding
-    padding[shard_dim - 2] = 0
+    if transpose:
+        padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size
+    else:
+        padding[shard_dim - 2] = 0
     args[0] = tensor_with_halo
     args[4] = padding
+    args[7] = output_padding
 
     # Perform the convolution operation
     out_tensor = func(*args, **kwargs)
@@ -305,19 +335,32 @@ def distconv_backward(
     args = list(args)
 
     # Unpack the necessary arguments
-    grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[
-        :7
-    ]
+    (
+        grad_out_tensor,
+        input_tensor,
+        weight,
+        bias_size,
+        stride,
+        padding,
+        dilation,
+        transpose,
+        output_padding,
+    ) = args[:9]
 
     # Extract the parallel strategy and shard dimension from the gradient output tensor
     parallel_strategy = grad_out_tensor._parallel_strategy
     shard_dim = parallel_strategy.shard_dim
     is_periodic = input_tensor._is_periodic
     if is_periodic:
-        assert padding[shard_dim - 2] == 0, (
-            "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
-        )
-        padding[shard_dim - 2] = input_tensor._periodic_shard_padding
+        if transpose:
+            padding[shard_dim - 2] -= (
+                stride[shard_dim - 2] * input_tensor._periodic_shard_padding
+            )
+        else:
+            assert padding[shard_dim - 2] == 0, (
+                "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
+            )
+            padding[shard_dim - 2] = input_tensor._periodic_shard_padding
 
     # Unwrap the underlying tensors from the DCTensors
     grad_out_tensor = grad_out_tensor._tensor
@@ -325,12 +368,18 @@
 
     # Check if the distributed convolution is supported with the given parameters
     check_is_distconv_supported(
-        shard_dim, input_torch_tensor, weight, stride, padding, dilation
+        shard_dim,
+        input_torch_tensor,
+        weight,
+        stride,
+        padding,
+        dilation,
+        transpose,
+        output_padding,
     )
 
     # Determine the halo size for halo exchange
-    kernel_size = weight.size(shard_dim)
-    halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
+    halo_size = padding[shard_dim - 2]
 
     # Get the input tensor including halos if available, otherwise perform forward halo exchange
     if input_tensor._tensor_with_halo is not None:
@@ -341,10 +390,14 @@
         )
 
     # Update the arguments with the gradient output tensor, input tensor including halos, and adjusted padding
-    padding[shard_dim - 2] = 0
+    if transpose:
+        padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size
+    else:
+        padding[shard_dim - 2] = 0
     args[0] = grad_out_tensor
     args[1] = input_tensor_with_halo
     args[5] = padding
+    args[8] = output_padding
 
     # Perform the backward convolution operation
     grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs)
diff --git a/tests/test_basic.py b/tests/test_basic.py
index d0748f5..9322cc1 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -2,10 +2,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from utils import cleanup_parallel_strategy, fp32_allclose
-
 from distconv import DCTensor, DistConvDDP, ParallelStrategy
+from utils import cleanup_parallel_strategy, fp32_allclose
+
 
 
 @pytest.fixture(scope="module")
 def parallel_strategy(device: torch.device):
diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py
new file mode 100644
index 0000000..f353955
--- /dev/null
+++ b/tests/test_convtranspose.py
@@ -0,0 +1,191 @@
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from distconv import DCTensor, DistConvDDP, ParallelStrategy
+
+from utils import cleanup_parallel_strategy, fp32_allclose
+
+
+def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> torch.Tensor:
+    """Gather tensors with the same number of dimensions but different lengths and concatenate them along dim.
+
+    Credit: https://stackoverflow.com/a/78934638
+    """
+    world_size = dist.get_world_size(group=group)
+    # Gather lengths first
+    shape = torch.as_tensor(tensor.shape, device=tensor.device)
+    shapes = [torch.empty_like(shape) for _ in range(world_size)]
+    dist.all_gather(shapes, shape, group=group)
+    # Gather data
+    inputs = [tensor] * world_size
+    outputs = [
+        torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device)
+        for _shape in shapes
+    ]
+    dist.all_to_all(outputs, inputs, group=group)
+    return torch.cat(outputs, dim=dim)
+
+
+@pytest.fixture(scope="module")
+def parallel_strategy(device: torch.device):
+    ps = ParallelStrategy(num_shards=4, device_type=device.type)
+    yield ps
+    cleanup_parallel_strategy(ps)
+
+
+def find_padding(kernel_size, stride=1, explicit_padding=False):
+    ep = kernel_size // 2 if explicit_padding else 0
+    pad = (kernel_size + 2 * ep * stride - 1) // 2
+    out_pad = stride - 1
+    if explicit_padding:
+        return pad, out_pad, ep
+    return pad, out_pad
+
+
+def generate_configs():
+    configs = []
+    for ndims in [1, 2, 3]:
+        for shard_dim in range(ndims):
+            for kernel_size in [1, 3, 5]:
+                for stride in [1, 2, 4]:
+                    configs.append((ndims, shard_dim, kernel_size, stride))
+
+    return "ndims,shard_dim,kernel_size,stride", configs
+
+
+@pytest.mark.parametrize(*generate_configs())
+def test_transposeconv_zerospadding(
+    parallel_strategy: ParallelStrategy,
+    ndims: int,
+    shard_dim: int,
+    kernel_size: int,
+    stride: int,
+    device: torch.device,
+):
+    """
+    Test distributed transposed convolution with different numbers of dimensions, kernel sizes, and strides.
+    Checks the output and gradients of the distributed convolution against the non-distributed
+    convolution.
+
+    Args:
+        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
+        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
+        shard_dim (int): Dimension along which the tensor is sharded.
+        kernel_size (int): Size of the convolution kernel.
+        stride (int): Stride of the convolution.
+        device (torch.device): Torch device to run test with.
+    """
+    # Set the shard dimension for the parallel strategy
+    parallel_strategy.shard_dim = 2 + shard_dim
+    padding, output_padding = find_padding(kernel_size, stride)
+
+    # Initialize the input tensor and convolution layer
+    shape = [1, 4] + [64] * ndims
+    x = torch.randn(*shape, device=device, requires_grad=True)
+    conv_class = getattr(nn, f"ConvTranspose{ndims}d")
+    conv = conv_class(
+        4,
+        8,
+        kernel_size=kernel_size,
+        padding=padding,
+        stride=stride,
+        output_padding=output_padding,
+    ).to(device)
+
+    # Perform forward and backward pass for reference (non-distributed) convolution
+    conv.zero_grad()
+    ref_y = conv(x)
+    ref_y.square().mean().backward()
+    ref_x_grad = x.grad
+    ref_conv_grad = conv.weight.grad
+
+    # Perform forward and backward pass for distributed convolution
+    conv.zero_grad()
+    dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
+    dcx = DCTensor.distribute(x, parallel_strategy)
+    dcy = dist_conv(dcx)
+    dcy_merge = dcy.to_replicate()
+    dc_loss = dcy.to_ddp().square().mean()
+    dc_loss.backward()
+    x_grad = dcx.grad.to_replicate()
+    dc_conv_grad = conv.weight.grad
+
+    assert fp32_allclose(ref_y, dcy_merge)
+    assert fp32_allclose(ref_x_grad, x_grad)
+    assert fp32_allclose(ref_conv_grad, dc_conv_grad)
+
+
+@pytest.mark.parametrize(*generate_configs())
+def test_transposeconv_circularpadding(
+    parallel_strategy: ParallelStrategy,
+    ndims: int,
+    shard_dim: int,
+    kernel_size: int,
+    stride: int,
+    device: torch.device,
+):
+    """
+    Test distributed transposed convolution with circular padding for different numbers of dimensions, kernel sizes, and strides.
+    Checks the output and gradients of the distributed convolution against the non-distributed
+    convolution.
+
+    Args:
+        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
+        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
+        shard_dim (int): Dimension along which the tensor is sharded.
+        kernel_size (int): Size of the convolution kernel.
+        stride (int): Stride of the convolution.
+        device (torch.device): Torch device to run test with.
+ """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding, output_padding, explicit_padding = find_padding( + kernel_size, stride, explicit_padding=True + ) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + + conv_kwargs = dict( + kernel_size=kernel_size, stride=stride, output_padding=output_padding + ) + + # set periodic padding case for reference + explicit_padding = [explicit_padding, explicit_padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=explicit_padding, mode="circular") + + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = ( + conv_class(4, 8, padding=padding, **conv_kwargs) + .to(device) + .requires_grad_(False) + ) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x_periodic) + ref_y.square().mean().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcx_periodic = torch.nn.functional.pad( + input=dcx, pad=explicit_padding, mode="circular" + ) + dcy = dist_conv(dcx_periodic) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + # Validate the results + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) diff --git a/tests/test_ddp_with_distconv.py b/tests/test_ddp_with_distconv.py index f683d21..189c38e 100644 --- a/tests/test_ddp_with_distconv.py +++ b/tests/test_ddp_with_distconv.py @@ -2,11 +2,11 @@ import torch import torch.distributed as dist import torch.nn as nn +from distconv import DCTensor, DistConvDDP, ParallelStrategy from torch.distributed.tensor import Replicate, Shard, distribute_tensor from torch.nn.parallel import DistributedDataParallel as DDP -from utils import cleanup_parallel_strategy, fp32_allclose -from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose @pytest.fixture(scope="module") diff --git a/tests/test_periodic.py b/tests/test_periodic.py index af6e309..9d5dd14 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -1,11 +1,13 @@ +from math import ceil + import pytest import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + def generate_configs(): configs = [] @@ -50,7 +52,7 @@ def test_periodic( conv_kwargs = dict( kernel_size=kernel_size, - padding=kernel_size // 2, + padding=ceil((kernel_size - stride) / 2), bias=False, stride=stride, padding_mode="circular", diff --git a/tests/test_strides.py b/tests/test_strides.py index efa82a6..d9e421e 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -2,10 +2,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from utils import cleanup_parallel_strategy, fp32_allclose - from distconv import DCTensor, DistConvDDP, ParallelStrategy +from utils import cleanup_parallel_strategy, fp32_allclose + 
 @pytest.fixture(scope="module")
 def parallel_strategy(device: torch.device):
diff --git a/tests/utils.py b/tests/utils.py
index 891254b..4e6fb39 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,6 +1,5 @@
 import torch
 import torch.distributed as dist
-
 from distconv import ParallelStrategy