diff --git a/distconv/distconv.py b/distconv/distconv.py
index 7316163..89c53b3 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -246,6 +246,33 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     parallel_strategy = tensor._parallel_strategy
     shard_dim = parallel_strategy.shard_dim
     is_periodic = tensor._is_periodic
+
+    # check for convNd_weight during double backprop
+    if weight.shape[shard_dim] * dilation[shard_dim - 2] == tensor.shape[shard_dim]:
+        args[0] = args[0].to_replicate()
+        args[1] = args[1].to_replicate()
+        # if is_periodic, account for shard padding
+        if is_periodic:
+            pad_map = (
+                [
+                    0,
+                ]
+                * (args[0].ndim - 2)
+                * 2
+            )
+            pad_map[2 * (shard_dim - 2)] = tensor._periodic_shard_padding
+            pad_map[2 * (shard_dim - 2) + 1] = tensor._periodic_shard_padding
+            args[0] = torch.nn.functional.pad(args[0], pad_map[::-1], mode="circular")
+
+        # Note: DDP already scales the gradients by the world size
+        grad_reduction_factor = parallel_strategy.ddp_ranks
+        grad_reduction_factor = dist.get_world_size() / grad_reduction_factor
+
+        grad_w = DCTensor(func(*args, **kwargs), tensor._parallel_strategy)
+        grad_w._is_periodic = tensor._is_periodic
+        grad_w._periodic_shard_padding = tensor._periodic_shard_padding
+        return grad_w / grad_reduction_factor
+
     if is_periodic:
         assert padding[shard_dim - 2] == 0, (
             "Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
@@ -284,7 +311,10 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
     out_tensor = func(*args, **kwargs)
 
     # Wrap the output tensor in a DCTensor and return it
-    return DCTensor(out_tensor, parallel_strategy)
+    out_tensor = DCTensor(out_tensor, parallel_strategy)
+    out_tensor._is_periodic = tensor._is_periodic
+    out_tensor._periodic_shard_padding = tensor._periodic_shard_padding
+    return out_tensor
 
 
 def distconv_backward(
@@ -357,6 +387,8 @@ def distconv_backward(
 
     # Wrap the gradient input tensor in a DCTensor
     grad_in_tensor = DCTensor(grad_in_tensor, parallel_strategy)
+    grad_in_tensor._is_periodic = input_tensor._is_periodic
+    grad_in_tensor._periodic_shard_padding = input_tensor._periodic_shard_padding
 
     # Return the gradients with respect to the input tensor, weight, and bias
     return grad_in_tensor, grad_weight, grad_bias
@@ -545,6 +577,8 @@ def _handle_circular_pad(cls, func, args, kwargs):
             shard_padding = 0
 
         # Call F.pad with modified padding (shard dim padding disabled)
+        input_tensor._is_periodic = shard_padding > 0
+        input_tensor._periodic_shard_padding = shard_padding
        new_args = (_ToTensor.apply(input_tensor), tuple(pad_list)) + args[2:]
         partial_padded_tensor = func(*new_args, **kwargs)
 
@@ -589,7 +623,10 @@ def unwrap(t):
 
         def wrap(t):
             if isinstance(t, torch.Tensor) and not isinstance(t, DCTensor):
-                return DCTensor(t, self._parallel_strategy)
+                out = DCTensor(t, self._parallel_strategy)
+                out._is_periodic = self._is_periodic
+                out._periodic_shard_padding = self._periodic_shard_padding
+                return out
             else:
                 return t
 
@@ -619,10 +656,13 @@ class _FromTensor(Function):
 
     @staticmethod
     def forward(ctx, tensor: torch.Tensor, parallel_strategy: ParallelStrategy):
+        ctx.parallel_strategy = parallel_strategy
         return DCTensor(tensor, parallel_strategy)
 
     @staticmethod
     def backward(ctx, grad: DCTensor):
+        if type(grad) != DCTensor:
+            grad = DCTensor.from_shard(grad, ctx.parallel_strategy)
         return _ToTensor.apply(grad), None
 
 
@@ -640,8 +680,13 @@ class _ToTensor(Function):
     @staticmethod
     def forward(ctx, dc_tensor: DCTensor):
         ctx.parallel_strategy = dc_tensor._parallel_strategy
+        ctx._is_periodic = dc_tensor._is_periodic
+        ctx._periodic_shard_padding = dc_tensor._periodic_shard_padding
         return dc_tensor._tensor
 
     @staticmethod
     def backward(ctx, grad: torch.Tensor):
-        return _FromTensor.apply(grad, ctx.parallel_strategy)
+        result = _FromTensor.apply(grad, ctx.parallel_strategy)
+        result._is_periodic = ctx._is_periodic
+        result._periodic_shard_padding = ctx._periodic_shard_padding
+        return result
diff --git a/tests/test_doublebackprop.py b/tests/test_doublebackprop.py
new file mode 100644
index 0000000..1e9a7dc
--- /dev/null
+++ b/tests/test_doublebackprop.py
@@ -0,0 +1,182 @@
+import pytest
+import torch
+import torch.nn as nn
+from utils import cleanup_parallel_strategy, fp32_allclose
+
+from distconv import DCTensor, DistConvDDP, ParallelStrategy
+
+
+@pytest.fixture(scope="module")
+def parallel_strategy(device: torch.device):
+    ps = ParallelStrategy(num_shards=2, device_type=device.type)
+    yield ps
+    cleanup_parallel_strategy(ps)
+
+
+def generate_configs():
+    configs = []
+    for ndims in [1, 2, 3]:
+        for shard_dim in range(ndims):
+            for kernel_size in [1, 3, 5]:
+                for num_shards in [2]:
+                    configs.append((ndims, shard_dim, kernel_size, num_shards))
+    return "ndims,shard_dim,kernel_size,num_shards", configs
+
+
+@pytest.mark.parametrize(*generate_configs())
+def test_double_backprop_gradientloss(
+    parallel_strategy: ParallelStrategy,
+    ndims: int,
+    shard_dim: int,
+    kernel_size: int,
+    num_shards: int,
+    device: torch.device,
+):
+    """
+    Test double backpropagation through the distributed convolution for different numbers of
+    dimensions and shard dimensions. Also considers hybrid spatial-data parallelism.
+    Uses a loss given by the mean of the input gradient and checks the loss, output, and
+    gradients against the reference (non-distributed) convolution.
+
+    Args:
+        ndims (int): Number of dimensions for the convolution (1, 2, or 3).
+        shard_dim (int): Dimension along which the tensor is sharded.
+        kernel_size (int): Size of the convolution kernel.
+        num_shards (int): Number of spatial partitions for the data.
+        device (torch.device): Torch device to run the test with.
+ """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = shard_dim + 2 + + conv_kwargs = dict( + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=False, + stride=1, + padding_mode="circular", + ) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [16] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + conv_class = getattr(nn, f"Conv{ndims}d") + conv = conv_class(4, 8, **conv_kwargs).to(device).requires_grad_(False) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x) + # find gradient wrt input + ref_grads = torch.autograd.grad( + outputs=[ref_y.sum()], inputs=[x], create_graph=True + )[0] + # find all losses + ref_loss_grad = ref_grads.mean() + ref_loss = ref_loss_grad + ref_loss.backward() + ref_conv_grad = conv.weight.grad.clone() + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcy = ddp_conv(dcx) + ddpy = dcy.to_replicate() + # find gradient wrt input + dc_grads = torch.autograd.grad( + outputs=[ddpy.sum()], inputs=[dcx], create_graph=True + )[0] + dc_grads_rep = dc_grads.to_replicate() + # find all losses + dc_loss_grad = dc_grads_rep.mean() + dc_loss = dc_loss_grad + dc_loss.backward() + dc_conv_grad = ddp_conv.module.weight.grad + + # Validate the results + assert fp32_allclose(ref_loss, dc_loss) + assert fp32_allclose(ref_y, ddpy) + assert fp32_allclose(ref_grads, dc_grads_rep) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) + + +@pytest.mark.parametrize(*generate_configs()) +def test_double_backprop_combinedloss( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + num_shards: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions and shard dimensions. + Also consider hybrid spatial-data parallelism. + Checks the output and gradients of the distributed convolution against the reference DDP + convolution. + + Args: + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + num_shards (int): Number of spatial partitions for data + device (torch.device): Torch device to run test with. 
+ """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = shard_dim + 2 + + conv_kwargs = dict( + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=False, + stride=1, + padding_mode="circular", + ) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [16] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + conv_class = getattr(nn, f"Conv{ndims}d") + conv = conv_class(4, 8, **conv_kwargs).to(device).requires_grad_(False) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x) + # find gradient wrt input + ref_grads = torch.autograd.grad( + outputs=[ref_y.sum()], inputs=[x], create_graph=True + )[0] + # find all losses + ref_loss_y = ref_y.square().norm() + ref_loss_grad = ref_grads.mean() + ref_loss = ref_loss_y + ref_loss_grad + ref_loss.backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad.clone() + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcy = ddp_conv(dcx) + ddpy = dcy.to_replicate() + # find gradient wrt input + dc_grads = torch.autograd.grad( + outputs=[ddpy.sum()], inputs=[dcx], create_graph=True + )[0] + dc_grads_rep = dc_grads.to_replicate() + # find all losses + dc_loss_y = ddpy.square().norm() + dc_loss_grad = dc_grads_rep.mean() + dc_loss = dc_loss_y + dc_loss_grad + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = ddp_conv.module.weight.grad + + # Validate the results + assert fp32_allclose(ref_loss, dc_loss) + assert fp32_allclose(ref_y, ddpy) + assert fp32_allclose(ref_grads, dc_grads_rep) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad)