From f4d6078f51b3fbcd2da56abc3a36411857181779 Mon Sep 17 00:00:00 2001
From: Pier Fiedorowicz
Date: Thu, 15 Jan 2026 09:42:45 -0800
Subject: [PATCH] Channels last implementation

---
 distconv/distconv.py        |  77 +++++++++++++--
 tests/test_channels_last.py | 189 ++++++++++++++++++++++++++++++++++++
 2 files changed, 257 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_channels_last.py

diff --git a/distconv/distconv.py b/distconv/distconv.py
index f000683..195f172 100644
--- a/distconv/distconv.py
+++ b/distconv/distconv.py
@@ -8,6 +8,17 @@
 from torch.utils._pytree import tree_map
 
 
+def get_memory_format(tensor: torch.Tensor) -> torch.memory_format:
+    """Detect the memory format of a tensor."""
+    if tensor.dim() == 4 and tensor.is_contiguous(memory_format=torch.channels_last):
+        return torch.channels_last
+    elif tensor.dim() == 5 and tensor.is_contiguous(
+        memory_format=torch.channels_last_3d
+    ):
+        return torch.channels_last_3d
+    return torch.contiguous_format
+
+
 class ParallelStrategy:
     """
     ParallelStrategy defines the strategy for distributing tensors across multiple devices
@@ -109,6 +120,9 @@ def forward_halo_exchange(
     if halo_size == 0:
         return tensor
 
+    # Detect memory format to preserve throughout operations
+    memory_format = get_memory_format(tensor)
+
     # Extract parallel strategy parameters
     shard_dim = parallel_strategy.shard_dim
     num_shards = parallel_strategy.num_shards
@@ -128,18 +142,30 @@ def forward_halo_exchange(
         # Receive halo from the previous rank and send their halo back
         ops += [
             dist.P2POp(dist.irecv, halo_minus, minus_rank),
-            dist.P2POp(dist.isend, inner_halo_minus.contiguous(), minus_rank),
+            dist.P2POp(
+                dist.isend,
+                inner_halo_minus.contiguous(memory_format=memory_format),
+                minus_rank,
+            ),
         ]
     if shard_ind < (num_shards - 1) or is_periodic:
         # Send halo to the next rank and receive their halo
         ops += [
-            dist.P2POp(dist.isend, inner_halo_plus.contiguous(), plus_rank),
+            dist.P2POp(
+                dist.isend,
+                inner_halo_plus.contiguous(memory_format=memory_format),
+                plus_rank,
+            ),
             dist.P2POp(dist.irecv, halo_plus, plus_rank),
         ]
     if shard_ind == 0 and is_periodic:
         ops += [
             dist.P2POp(dist.irecv, halo_minus, minus_rank),
-            dist.P2POp(dist.isend, inner_halo_minus.contiguous(), minus_rank),
+            dist.P2POp(
+                dist.isend,
+                inner_halo_minus.contiguous(memory_format=memory_format),
+                minus_rank,
+            ),
         ]
 
     # Execute communication operations
@@ -175,6 +201,9 @@ def backward_halo_exchange(
     if halo_size == 0:
         return tensor
 
+    # Detect memory format to preserve throughout operations
+    memory_format = get_memory_format(tensor)
+
     # Extract parallel strategy parameters
     shard_dim = parallel_strategy.shard_dim
     num_shards = parallel_strategy.num_shards
@@ -194,18 +223,30 @@ def backward_halo_exchange(
         # Receive halo from previous rank and send their halo back
         ops += [
             dist.P2POp(dist.irecv, recv_halo_minus, minus_rank),
-            dist.P2POp(dist.isend, send_halo_minus.contiguous(), minus_rank),
+            dist.P2POp(
+                dist.isend,
+                send_halo_minus.contiguous(memory_format=memory_format),
+                minus_rank,
+            ),
         ]
     if shard_ind < (num_shards - 1) or is_periodic:
         # Send halo to the next rank and receive their halo
         ops += [
-            dist.P2POp(dist.isend, send_halo_plus.contiguous(), plus_rank),
+            dist.P2POp(
+                dist.isend,
+                send_halo_plus.contiguous(memory_format=memory_format),
+                plus_rank,
+            ),
             dist.P2POp(dist.irecv, recv_halo_plus, plus_rank),
         ]
     if shard_ind == 0 and is_periodic:
         ops += [
             dist.P2POp(dist.irecv, recv_halo_minus, minus_rank),
-            dist.P2POp(dist.isend, send_halo_minus.contiguous(), minus_rank),
+            dist.P2POp(
+                dist.isend,
+                send_halo_minus.contiguous(memory_format=memory_format),
+                minus_rank,
+            ),
         ]
 
     # Execute communication operations
@@ -436,12 +477,18 @@ def distribute(
         Returns:
             DCTensor: A new instance of DCTensor with the tensor sharded according to the parallel strategy.
         """
+        # Preserve memory format through distribution
+        memory_format = get_memory_format(tensor)
         dtensor = distribute_tensor(
             tensor,
             device_mesh=parallel_strategy.device_mesh["dc"],
             placements=[Shard(parallel_strategy.shard_dim)],
         )
-        return cls(dtensor.to_local(), parallel_strategy)
+        local_tensor = dtensor.to_local()
+        # DTensor may not preserve memory format, so convert back if needed
+        if memory_format != torch.contiguous_format:
+            local_tensor = local_tensor.contiguous(memory_format=memory_format)
+        return cls(local_tensor, parallel_strategy)
 
     def to_ddp(self) -> torch.Tensor:
         """
@@ -450,6 +497,8 @@ def to_ddp(self) -> torch.Tensor:
         Returns:
             torch.Tensor: The tensor resharded to the batch dimension.
         """
+        # Preserve memory format through redistribution
+        memory_format = get_memory_format(self._tensor)
         device_mesh = self._parallel_strategy.device_mesh["dc"]
         shard_dim = self._parallel_strategy.shard_dim
         dtensor = DTensor.from_local(
@@ -457,7 +506,11 @@ def to_ddp(self) -> torch.Tensor:
             device_mesh=device_mesh,
             placements=[Shard(shard_dim)],
         ).redistribute(device_mesh=device_mesh, placements=[Shard(0)])
-        return dtensor.to_local()
+        local_tensor = dtensor.to_local()
+        # DTensor may not preserve memory format, so convert back if needed
+        if memory_format != torch.contiguous_format:
+            local_tensor = local_tensor.contiguous(memory_format=memory_format)
+        return local_tensor
 
     def to_replicate(self) -> torch.Tensor:
         """
@@ -466,6 +519,8 @@ def to_replicate(self) -> torch.Tensor:
         Returns:
             torch.Tensor: The full tensor.
         """
+        # Preserve memory format through redistribution
+        memory_format = get_memory_format(self._tensor)
         device_mesh = self._parallel_strategy.device_mesh["dc"]
         shard_dim = self._parallel_strategy.shard_dim
         dtensor = DTensor.from_local(
@@ -473,7 +528,11 @@ def to_replicate(self) -> torch.Tensor:
             device_mesh=device_mesh,
             placements=[Shard(shard_dim)],
         ).redistribute(device_mesh=device_mesh, placements=[Replicate()])
-        return dtensor.to_local()
+        local_tensor = dtensor.to_local()
+        # DTensor may not preserve memory format, so convert back if needed
+        if memory_format != torch.contiguous_format:
+            local_tensor = local_tensor.contiguous(memory_format=memory_format)
+        return local_tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
diff --git a/tests/test_channels_last.py b/tests/test_channels_last.py
new file mode 100644
index 0000000..c231dd7
--- /dev/null
+++ b/tests/test_channels_last.py
@@ -0,0 +1,189 @@
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from utils import cleanup_parallel_strategy, fp32_allclose
+
+from distconv import DCTensor, DistConvDDP, ParallelStrategy
+
+
+@pytest.fixture(scope="module")
+def parallel_strategy(device: torch.device):
+    ps = ParallelStrategy(num_shards=4, device_type=device.type)
+    yield ps
+    cleanup_parallel_strategy(ps)
+
+
+def generate_channels_last_configs():
+    """Generate test configurations for channels_last testing."""
+    configs = []
+    # Test 2D convolutions with channels_last (ndims=2 -> Conv2d)
+    for shard_dim in range(2):  # H or W dimension
+        for kernel_size in [1, 3, 5]:
+            configs.append((2, shard_dim, kernel_size, torch.channels_last))
+
+    # Test 3D convolutions with channels_last_3d (ndims=3 -> Conv3d)
+    for shard_dim in range(3):  # D, H, or W dimension
+        for kernel_size in [1, 3, 5]:
+            configs.append((3, shard_dim, kernel_size, torch.channels_last_3d))
+
+    return "ndims,shard_dim,kernel_size,memory_format", configs
+
+
+@pytest.mark.parametrize(*generate_channels_last_configs())
+def test_channels_last_forward_backward(
+    parallel_strategy: ParallelStrategy,
+    ndims: int,
+    shard_dim: int,
+    kernel_size: int,
+    memory_format: torch.memory_format,
+    device: torch.device,
+):
+    """
+    Test distributed convolution with channels_last memory format.
+    Verifies correctness and that memory format is preserved.
+
+    Args:
+        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
+        ndims (int): Number of dimensions for the convolution (2 or 3).
+        shard_dim (int): Dimension along which the tensor is sharded.
+        kernel_size (int): Size of the convolution kernel.
+        memory_format (torch.memory_format): Memory format to use.
+        device (torch.device): Torch device to run test with.
+    """
+    parallel_strategy.shard_dim = 2 + shard_dim
+
+    # Create input tensor with channels_last format
+    shape = [1, 4] + [64] * ndims
+    x = torch.randn(*shape, device=device).to(memory_format=memory_format).requires_grad_(True)
+
+    # Verify input is in channels_last format
+    assert x.is_contiguous(memory_format=memory_format)
+
+    # Create convolution layer with channels_last format
+    conv_class = getattr(nn, f"Conv{ndims}d")
+    conv = conv_class(4, 8, kernel_size=kernel_size, padding=kernel_size // 2).to(
+        device
+    )
+    conv = conv.to(memory_format=memory_format)
+
+    # Reference forward/backward
+    conv.zero_grad()
+    ref_y = conv(x)
+    ref_y.square().mean().backward()
+    ref_x_grad = x.grad.clone()
+    ref_conv_grad = conv.weight.grad.clone()
+
+    # Distributed forward/backward
+    conv.zero_grad()
+    x.grad = None
+    ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
+    dcx = DCTensor.distribute(x, parallel_strategy)
+
+    # Verify DCTensor preserves channels_last format
+    assert dcx._tensor.is_contiguous(memory_format=memory_format), (
+        "DCTensor did not preserve channels_last format"
+    )
+
+    dcy = ddp_conv(dcx)
+
+    # Verify output preserves channels_last format
+    assert dcy._tensor.is_contiguous(memory_format=memory_format), (
+        "Output did not preserve channels_last format"
+    )
+
+    ddpy = dcy.to_ddp()
+    ddpy.square().mean().backward()
+    x_grad = dcx.grad.to_ddp()
+    dc_conv_grad = conv.weight.grad
+
+    # Validate numerical correctness
+    if dist.get_rank() == 0:
+        assert fp32_allclose(ref_y, ddpy)
+    else:
+        assert ddpy.numel() == 0
+    assert fp32_allclose(ref_x_grad, x_grad)
+    assert fp32_allclose(ref_conv_grad, dc_conv_grad)
+
+
+def generate_periodic_channels_last_configs():
+    """Generate test configurations for periodic padding with channels_last."""
+    configs = []
+    # Test 2D convolutions with channels_last
+    for shard_dim in range(2):
+        for kernel_size in [3, 5]:
+            configs.append((2, shard_dim, kernel_size, torch.channels_last))
+
+    # Test 3D convolutions with channels_last_3d
+    for shard_dim in range(3):
+        for kernel_size in [3, 5]:
+            configs.append((3, shard_dim, kernel_size, torch.channels_last_3d))
+
+    return "ndims,shard_dim,kernel_size,memory_format", configs
+
+
+@pytest.mark.parametrize(*generate_periodic_channels_last_configs())
+def test_channels_last_periodic_padding(
+    parallel_strategy: ParallelStrategy,
+    ndims: int,
+    shard_dim: int,
+    kernel_size: int,
+    memory_format: torch.memory_format,
+    device: torch.device,
+):
+    """
+    Test periodic padding with channels_last format.
+
+    Args:
+        parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
+        ndims (int): Number of dimensions for the convolution (2 or 3).
+        shard_dim (int): Dimension along which the tensor is sharded.
+        kernel_size (int): Size of the convolution kernel.
+        memory_format (torch.memory_format): Memory format to use.
+        device (torch.device): Torch device to run test with.
+    """
+    parallel_strategy.shard_dim = 2 + shard_dim
+
+    # Create input tensor with channels_last format
+    shape = [1, 4] + [64] * ndims
+    x = torch.randn(*shape, device=device).to(memory_format=memory_format).requires_grad_(True)
+
+    # Create convolution layer with circular padding
+    conv_class = getattr(nn, f"Conv{ndims}d")
+    conv = conv_class(
+        4, 8, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="circular"
+    ).to(device)
+    conv = conv.to(memory_format=memory_format)
+
+    # Reference forward/backward
+    conv.zero_grad()
+    ref_y = conv(x)
+    ref_y.square().mean().backward()
+    ref_x_grad = x.grad.clone()
+    ref_conv_grad = conv.weight.grad.clone()
+
+    # Distributed forward/backward
+    conv.zero_grad()
+    x.grad = None
+    ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
+    dcx = DCTensor.distribute(x, parallel_strategy)
+
+    dcy = ddp_conv(dcx)
+
+    # Verify output preserves channels_last format
+    assert dcy._tensor.is_contiguous(memory_format=memory_format), (
+        "Output did not preserve channels_last format with periodic padding"
+    )
+
+    ddpy = dcy.to_ddp()
+    ddpy.square().mean().backward()
+    x_grad = dcx.grad.to_ddp()
+    dc_conv_grad = conv.weight.grad
+
+    # Validate numerical correctness
+    if dist.get_rank() == 0:
+        assert fp32_allclose(ref_y, ddpy)
+    else:
+        assert ddpy.numel() == 0
+    assert fp32_allclose(ref_x_grad, x_grad)
+    assert fp32_allclose(ref_conv_grad, dc_conv_grad)