77 changes: 68 additions & 9 deletions distconv/distconv.py
@@ -8,6 +8,17 @@
from torch.utils._pytree import tree_map


def get_memory_format(tensor: torch.Tensor) -> torch.memory_format:
"""Detect the memory format of a tensor."""
if tensor.dim() == 4 and tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
elif tensor.dim() == 5 and tensor.is_contiguous(
memory_format=torch.channels_last_3d
):
return torch.channels_last_3d
return torch.contiguous_format
Review comment: Could it be non-contiguous in other ways, e.g. via some transposition?

To clarify the intent, I'd rename the function to get_memory_format_for_halo.
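A minimal standalone sketch of that concern (the helper is re-declared here so the snippet runs on its own): a tensor made non-contiguous by a plain transpose matches neither channels_last check and falls through to torch.contiguous_format. That is still safe for the halo send path, because .contiguous(memory_format=...) densifies the strided slice either way; the rename would just make the narrower purpose explicit.

import torch


def get_memory_format(tensor: torch.Tensor) -> torch.memory_format:
    # Copy of the helper above so this snippet is self-contained.
    if tensor.dim() == 4 and tensor.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    elif tensor.dim() == 5 and tensor.is_contiguous(
        memory_format=torch.channels_last_3d
    ):
        return torch.channels_last_3d
    return torch.contiguous_format


x = torch.randn(2, 4, 8, 8)                    # NCHW, row-major contiguous
assert get_memory_format(x) == torch.contiguous_format

cl = x.to(memory_format=torch.channels_last)   # NHWC strides
assert get_memory_format(cl) == torch.channels_last

t = x.transpose(2, 3)                          # non-contiguous via transposition
assert not t.is_contiguous()
# Neither channels_last nor channels_last_3d matches, so the helper reports the default.
assert get_memory_format(t) == torch.contiguous_format
# For the halo exchange this is still fine: .contiguous(memory_format=...) copies the
# strided view into a dense buffer before it is handed to dist.isend.
buf = t.contiguous(memory_format=get_memory_format(t))
assert buf.is_contiguous()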



class ParallelStrategy:
"""
ParallelStrategy defines the strategy for distributing tensors across multiple devices
@@ -109,6 +120,9 @@ def forward_halo_exchange(
if halo_size == 0:
return tensor

# Detect memory format to preserve throughout operations
memory_format = get_memory_format(tensor)

# Extract parallel strategy parameters
shard_dim = parallel_strategy.shard_dim
num_shards = parallel_strategy.num_shards
@@ -128,18 +142,30 @@
# Receive halo from the previous rank and send their halo back
ops += [
dist.P2POp(dist.irecv, halo_minus, minus_rank),
dist.P2POp(dist.isend, inner_halo_minus.contiguous(), minus_rank),
dist.P2POp(
dist.isend,
inner_halo_minus.contiguous(memory_format=memory_format),
minus_rank,
),
]
if shard_ind < (num_shards - 1) or is_periodic:
# Send halo to the next rank and receive their halo
ops += [
dist.P2POp(dist.isend, inner_halo_plus.contiguous(), plus_rank),
dist.P2POp(
dist.isend,
inner_halo_plus.contiguous(memory_format=memory_format),
plus_rank,
),
dist.P2POp(dist.irecv, halo_plus, plus_rank),
]
if shard_ind == 0 and is_periodic:
ops += [
dist.P2POp(dist.irecv, halo_minus, minus_rank),
dist.P2POp(dist.isend, inner_halo_minus.contiguous(), minus_rank),
dist.P2POp(
dist.isend,
inner_halo_minus.contiguous(memory_format=memory_format),
minus_rank,
),
]

# Execute communication operations
@@ -175,6 +201,9 @@ def backward_halo_exchange(
if halo_size == 0:
return tensor

# Detect memory format to preserve throughout operations
memory_format = get_memory_format(tensor)

# Extract parallel strategy parameters
shard_dim = parallel_strategy.shard_dim
num_shards = parallel_strategy.num_shards
@@ -194,18 +223,30 @@
# Receive halo from previous rank and send their halo back
ops += [
dist.P2POp(dist.irecv, recv_halo_minus, minus_rank),
dist.P2POp(dist.isend, send_halo_minus.contiguous(), minus_rank),
dist.P2POp(
dist.isend,
send_halo_minus.contiguous(memory_format=memory_format),
minus_rank,
),
]
if shard_ind < (num_shards - 1) or is_periodic:
# Send halo to the next rank and receive their halo
ops += [
dist.P2POp(dist.isend, send_halo_plus.contiguous(), plus_rank),
dist.P2POp(
dist.isend,
send_halo_plus.contiguous(memory_format=memory_format),
plus_rank,
),
dist.P2POp(dist.irecv, recv_halo_plus, plus_rank),
]
if shard_ind == 0 and is_periodic:
ops += [
dist.P2POp(dist.irecv, recv_halo_minus, minus_rank),
dist.P2POp(dist.isend, send_halo_minus.contiguous(), minus_rank),
dist.P2POp(
dist.isend,
send_halo_minus.contiguous(memory_format=memory_format),
minus_rank,
),
]

# Execute communication operations
@@ -436,12 +477,18 @@ def distribute(
Returns:
DCTensor: A new instance of DCTensor with the tensor sharded according to the parallel strategy.
"""
# Preserve memory format through distribution
memory_format = get_memory_format(tensor)
dtensor = distribute_tensor(
tensor,
device_mesh=parallel_strategy.device_mesh["dc"],
placements=[Shard(parallel_strategy.shard_dim)],
)
return cls(dtensor.to_local(), parallel_strategy)
local_tensor = dtensor.to_local()
# DTensor may not preserve memory format, so convert back if needed
if memory_format != torch.contiguous_format:
local_tensor = local_tensor.contiguous(memory_format=memory_format)
return cls(local_tensor, parallel_strategy)

def to_ddp(self) -> torch.Tensor:
"""
@@ -450,14 +497,20 @@
Returns:
torch.Tensor: The tensor resharded to the batch dimension.
"""
# Preserve memory format through redistribution
memory_format = get_memory_format(self._tensor)
device_mesh = self._parallel_strategy.device_mesh["dc"]
shard_dim = self._parallel_strategy.shard_dim
dtensor = DTensor.from_local(
_ToTensor.apply(self),
device_mesh=device_mesh,
placements=[Shard(shard_dim)],
).redistribute(device_mesh=device_mesh, placements=[Shard(0)])
return dtensor.to_local()
local_tensor = dtensor.to_local()
# DTensor may not preserve memory format, so convert back if needed
if memory_format != torch.contiguous_format:
local_tensor = local_tensor.contiguous(memory_format=memory_format)
return local_tensor

def to_replicate(self) -> torch.Tensor:
"""
@@ -466,14 +519,20 @@
Returns:
torch.Tensor: The full tensor.
"""
# Preserve memory format through redistribution
memory_format = get_memory_format(self._tensor)
device_mesh = self._parallel_strategy.device_mesh["dc"]
shard_dim = self._parallel_strategy.shard_dim
dtensor = DTensor.from_local(
_ToTensor.apply(self),
device_mesh=device_mesh,
placements=[Shard(shard_dim)],
).redistribute(device_mesh=device_mesh, placements=[Replicate()])
return dtensor.to_local()
local_tensor = dtensor.to_local()
# DTensor may not preserve memory format, so convert back if needed
if memory_format != torch.contiguous_format:
local_tensor = local_tensor.contiguous(memory_format=memory_format)
return local_tensor

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
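For reference, distribute(), to_ddp(), and to_replicate() above all follow the same pattern: record the incoming memory format, let DTensor do its redistribution, then reapply the format if the default layout came back. A minimal single-process sketch of just that capture-and-restore step, with a plain .contiguous() standing in for the DTensor round-trip (distribute_tensor and DTensor are not used here):

import torch

x = torch.randn(1, 4, 64, 64).to(memory_format=torch.channels_last)
memory_format = torch.channels_last          # what get_memory_format(x) reports for this input

# Stand-in for a redistribute()/to_local() round-trip that may come back in the default layout.
local_tensor = x.clone().contiguous()        # NCHW layout, channels_last strides lost
assert not local_tensor.is_contiguous(memory_format=torch.channels_last)

# Same restore step as in distribute(), to_ddp(), and to_replicate() above.
if memory_format != torch.contiguous_format:
    local_tensor = local_tensor.contiguous(memory_format=memory_format)
assert local_tensor.is_contiguous(memory_format=torch.channels_last)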
189 changes: 189 additions & 0 deletions tests/test_channels_last.py
@@ -0,0 +1,189 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from utils import cleanup_parallel_strategy, fp32_allclose

from distconv import DCTensor, DistConvDDP, ParallelStrategy


@pytest.fixture(scope="module")
def parallel_strategy(device: torch.device):
ps = ParallelStrategy(num_shards=4, device_type=device.type)
yield ps
cleanup_parallel_strategy(ps)


def generate_channels_last_configs():
"""Generate test configurations for channels_last testing."""
configs = []
# Test 2D convolutions with channels_last (ndims=2 -> Conv2d)
for shard_dim in range(2): # H or W dimension
for kernel_size in [1, 3, 5]:
configs.append((2, shard_dim, kernel_size, torch.channels_last))

# Test 3D convolutions with channels_last_3d (ndims=3 -> Conv3d)
for shard_dim in range(3): # D, H, or W dimension
for kernel_size in [1, 3, 5]:
configs.append((3, shard_dim, kernel_size, torch.channels_last_3d))

return "ndims,shard_dim,kernel_size,memory_format", configs


@pytest.mark.parametrize(*generate_channels_last_configs())
def test_channels_last_forward_backward(
parallel_strategy: ParallelStrategy,
ndims: int,
shard_dim: int,
kernel_size: int,
memory_format: torch.memory_format,
device: torch.device,
):
"""
Test distributed convolution with channels_last memory format.
Verifies correctness and that memory format is preserved.

Args:
parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
ndims (int): Number of dimensions for the convolution (2 or 3).
shard_dim (int): Dimension along which the tensor is sharded.
kernel_size (int): Size of the convolution kernel.
memory_format (torch.memory_format): Memory format to use.
device (torch.device): Torch device to run test with.
"""
parallel_strategy.shard_dim = 2 + shard_dim

# Create input tensor with channels_last format
shape = [1, 4] + [64] * ndims
x = torch.randn(*shape, device=device).to(memory_format=memory_format).requires_grad_(True)

# Verify input is in channels_last format
assert x.is_contiguous(memory_format=memory_format)

# Create convolution layer with channels_last format
conv_class = getattr(nn, f"Conv{ndims}d")
conv = conv_class(4, 8, kernel_size=kernel_size, padding=kernel_size // 2).to(
device
)
conv = conv.to(memory_format=memory_format)

# Reference forward/backward
conv.zero_grad()
ref_y = conv(x)
ref_y.square().mean().backward()
ref_x_grad = x.grad.clone()
ref_conv_grad = conv.weight.grad.clone()

# Distributed forward/backward
conv.zero_grad()
x.grad = None
ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
dcx = DCTensor.distribute(x, parallel_strategy)

# Verify DCTensor preserves channels_last format
assert dcx._tensor.is_contiguous(memory_format=memory_format), (
"DCTensor did not preserve channels_last format"
)

dcy = ddp_conv(dcx)

# Verify output preserves channels_last format
assert dcy._tensor.is_contiguous(memory_format=memory_format), (
"Output did not preserve channels_last format"
)

ddpy = dcy.to_ddp()
ddpy.square().mean().backward()
x_grad = dcx.grad.to_ddp()
dc_conv_grad = conv.weight.grad

# Validate numerical correctness
if dist.get_rank() == 0:
assert fp32_allclose(ref_y, ddpy)
else:
assert ddpy.numel() == 0
assert fp32_allclose(ref_x_grad, x_grad)
assert fp32_allclose(ref_conv_grad, dc_conv_grad)


def generate_periodic_channels_last_configs():
"""Generate test configurations for periodic padding with channels_last."""
configs = []
# Test 2D convolutions with channels_last
for shard_dim in range(2):
for kernel_size in [3, 5]:
configs.append((2, shard_dim, kernel_size, torch.channels_last))

# Test 3D convolutions with channels_last_3d
for shard_dim in range(3):
for kernel_size in [3, 5]:
configs.append((3, shard_dim, kernel_size, torch.channels_last_3d))

return "ndims,shard_dim,kernel_size,memory_format", configs


@pytest.mark.parametrize(*generate_periodic_channels_last_configs())
def test_channels_last_periodic_padding(
parallel_strategy: ParallelStrategy,
ndims: int,
shard_dim: int,
kernel_size: int,
memory_format: torch.memory_format,
device: torch.device,
):
"""
Test periodic padding with channels_last format.

Args:
parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution.
ndims (int): Number of dimensions for the convolution (2 or 3).
shard_dim (int): Dimension along which the tensor is sharded.
kernel_size (int): Size of the convolution kernel.
memory_format (torch.memory_format): Memory format to use.
device (torch.device): Torch device to run test with.
"""
parallel_strategy.shard_dim = 2 + shard_dim

# Create input tensor with channels_last format
shape = [1, 4] + [64] * ndims
x = torch.randn(*shape, device=device).to(memory_format=memory_format).requires_grad_(True)

# Create convolution layer with circular padding
conv_class = getattr(nn, f"Conv{ndims}d")
conv = conv_class(
4, 8, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="circular"
).to(device)
conv = conv.to(memory_format=memory_format)

# Reference forward/backward
conv.zero_grad()
ref_y = conv(x)
ref_y.square().mean().backward()
ref_x_grad = x.grad.clone()
ref_conv_grad = conv.weight.grad.clone()

# Distributed forward/backward
conv.zero_grad()
x.grad = None
ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
dcx = DCTensor.distribute(x, parallel_strategy)

dcy = ddp_conv(dcx)

# Verify output preserves channels_last format
assert dcy._tensor.is_contiguous(memory_format=memory_format), (
"Output did not preserve channels_last format with periodic padding"
)

ddpy = dcy.to_ddp()
ddpy.square().mean().backward()
x_grad = dcx.grad.to_ddp()
dc_conv_grad = conv.weight.grad

# Validate numerical correctness
if dist.get_rank() == 0:
assert fp32_allclose(ref_y, ddpy)
else:
assert ddpy.numel() == 0
assert fp32_allclose(ref_x_grad, x_grad)
assert fp32_allclose(ref_conv_grad, dc_conv_grad)
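As a purely local sanity check of the property these distributed tests exercise end to end (no ParallelStrategy, DCTensor, or DistConvDDP involved): on recent PyTorch builds, a convolution whose input and weights are channels_last should produce a channels_last output with the same numerics as the default NCHW run. A rough single-device sketch:

import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 8, kernel_size=3, padding=1)
x = torch.randn(1, 4, 64, 64)

ref = conv(x)                                                # default NCHW reference

conv_cl = copy.deepcopy(conv).to(memory_format=torch.channels_last)
x_cl = x.to(memory_format=torch.channels_last)
out = conv_cl(x_cl)

assert out.is_contiguous(memory_format=torch.channels_last)  # format propagates through the conv
assert torch.allclose(ref, out, atol=1e-5)                   # numerics match the NCHW run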