51 changes: 48 additions & 3 deletions distconv/distconv.py
@@ -246,6 +246,33 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
parallel_strategy = tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
is_periodic = tensor._is_periodic

# check for convNd_weight during double backprop
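    # (The kernel extent scaled by the dilation matching the sharded input extent on the
    # shard dimension is taken to identify the weight-gradient convolution; both operands
    # are replicated below so that this call runs on the full tensors.)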
if weight.shape[shard_dim] * dilation[shard_dim - 2] == tensor.shape[shard_dim]:
args[0] = args[0].to_replicate()
args[1] = args[1].to_replicate()
# if is_periodic, account for shard padding
if is_periodic:
            pad_map = [0] * ((args[0].ndim - 2) * 2)
pad_map[2 * (shard_dim - 2)] = tensor._periodic_shard_padding
pad_map[2 * (shard_dim - 2) + 1] = tensor._periodic_shard_padding
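            # F.pad expects pad widths ordered from the last dimension backwards; the
            # padding here is symmetric on every dimension, so reversing pad_map
            # yields the layout F.pad expects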
args[0] = torch.nn.functional.pad(args[0], pad_map[::-1], mode="circular")

# Note: DDP already scales the gradients by the world size
        grad_reduction_factor = dist.get_world_size() / parallel_strategy.ddp_ranks
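        # e.g., assuming a layout of 2 spatial shards x 2 DDP replicas
        # (dist.get_world_size() == 4, parallel_strategy.ddp_ranks == 2), the factor
        # is 4 / 2 == 2, i.e. the weight gradient is scaled down by the number of
        # spatial shards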

grad_w = DCTensor(func(*args, **kwargs), tensor._parallel_strategy)
grad_w._is_periodic = tensor._is_periodic
grad_w._periodic_shard_padding = tensor._periodic_shard_padding
return grad_w / grad_reduction_factor

if is_periodic:
assert padding[shard_dim - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
@@ -284,7 +311,10 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
out_tensor = func(*args, **kwargs)

# Wrap the output tensor in a DCTensor and return it
return DCTensor(out_tensor, parallel_strategy)
out_tensor = DCTensor(out_tensor, parallel_strategy)
out_tensor._is_periodic = tensor._is_periodic
out_tensor._periodic_shard_padding = tensor._periodic_shard_padding
return out_tensor


def distconv_backward(
@@ -357,6 +387,8 @@ def distconv_backward(

# Wrap the gradient input tensor in a DCTensor
grad_in_tensor = DCTensor(grad_in_tensor, parallel_strategy)
grad_in_tensor._is_periodic = input_tensor._is_periodic
grad_in_tensor._periodic_shard_padding = input_tensor._periodic_shard_padding

# Return the gradients with respect to the input tensor, weight, and bias
return grad_in_tensor, grad_weight, grad_bias
@@ -545,6 +577,8 @@ def _handle_circular_pad(cls, func, args, kwargs):
shard_padding = 0

# Call F.pad with modified padding (shard dim padding disabled)
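        # (the disabled shard-dim padding is recorded on the tensor below so that it can
        # be re-applied downstream once the shards are replicated)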
input_tensor._is_periodic = shard_padding > 0
input_tensor._periodic_shard_padding = shard_padding
new_args = (_ToTensor.apply(input_tensor), tuple(pad_list)) + args[2:]
partial_padded_tensor = func(*new_args, **kwargs)

@@ -589,7 +623,10 @@ def unwrap(t):

def wrap(t):
if isinstance(t, torch.Tensor) and not isinstance(t, DCTensor):
return DCTensor(t, self._parallel_strategy)
out = DCTensor(t, self._parallel_strategy)
out._is_periodic = self._is_periodic
out._periodic_shard_padding = self._periodic_shard_padding
return out
else:
return t

@@ -619,10 +656,13 @@ class _FromTensor(Function):

@staticmethod
def forward(ctx, tensor: torch.Tensor, parallel_strategy: ParallelStrategy):
ctx.parallel_strategy = parallel_strategy
return DCTensor(tensor, parallel_strategy)

@staticmethod
def backward(ctx, grad: DCTensor):
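        # During double backprop the incoming grad may arrive as a plain torch.Tensor;
        # treat it as this rank's local shard and re-wrap it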
if type(grad) != DCTensor:
grad = DCTensor.from_shard(grad, ctx.parallel_strategy)
return _ToTensor.apply(grad), None


@@ -640,8 +680,13 @@ class _ToTensor(Function):
@staticmethod
def forward(ctx, dc_tensor: DCTensor):
ctx.parallel_strategy = dc_tensor._parallel_strategy
ctx._is_periodic = dc_tensor._is_periodic
ctx._periodic_shard_padding = dc_tensor._periodic_shard_padding
return dc_tensor._tensor

@staticmethod
def backward(ctx, grad: torch.Tensor):
return _FromTensor.apply(grad, ctx.parallel_strategy)
result = _FromTensor.apply(grad, ctx.parallel_strategy)
result._is_periodic = ctx._is_periodic
result._periodic_shard_padding = ctx._periodic_shard_padding
return result
182 changes: 182 additions & 0 deletions tests/test_doublebackprop.py
@@ -0,0 +1,182 @@
import pytest
import torch
import torch.nn as nn
from utils import cleanup_parallel_strategy, fp32_allclose

from distconv import DCTensor, DistConvDDP, ParallelStrategy


@pytest.fixture(scope="module")
def parallel_strategy(device: torch.device):
ps = ParallelStrategy(num_shards=2, device_type=device.type)
yield ps
cleanup_parallel_strategy(ps)


def generate_configs():
configs = []
for ndims in [1, 2, 3]:
for shard_dim in range(ndims):
for kernel_size in [1, 3, 5]:
for num_shards in [2]:
configs.append((ndims, shard_dim, kernel_size, num_shards))
return "ndims,shard_dim,kernel_size,num_shards", configs


@pytest.mark.parametrize(*generate_configs())
def test_double_backprop_gradientloss(
parallel_strategy: ParallelStrategy,
ndims: int,
shard_dim: int,
kernel_size: int,
num_shards: int,
device: torch.device,
):
"""
    Test double backpropagation through a distributed convolution where the loss is a
    function of the gradient with respect to the input.
    Checks the loss, output, input gradient, and weight gradient of the distributed
    convolution against the reference (non-distributed) convolution.

Args:
ndims (int): Number of dimensions for the convolution (1, 2, or 3).
shard_dim (int): Dimension along which the tensor is sharded.
kernel_size (int): Size of the convolution kernel.
        num_shards (int): Number of spatial shards (partitions along the shard dimension).
        device (torch.device): Torch device to run the test on.
"""
# Set the shard dimension for the parallel strategy
parallel_strategy.shard_dim = shard_dim + 2

conv_kwargs = dict(
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=False,
stride=1,
padding_mode="circular",
)

# Initialize the input tensor and convolution layer
shape = [1, 4] + [16] * ndims
x = torch.randn(*shape, device=device, requires_grad=True)
conv_class = getattr(nn, f"Conv{ndims}d")
conv = conv_class(4, 8, **conv_kwargs).to(device).requires_grad_(False)
conv.requires_grad_(True)

# Perform forward and backward pass for reference (non-distributed) convolution
conv.zero_grad()
ref_y = conv(x)
# find gradient wrt input
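    # create_graph=True keeps the autograd graph so ref_grads can itself be
    # differentiated in the second backward pass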
ref_grads = torch.autograd.grad(
outputs=[ref_y.sum()], inputs=[x], create_graph=True
)[0]
    # the loss is a function of the input gradient only
ref_loss_grad = ref_grads.mean()
ref_loss = ref_loss_grad
ref_loss.backward()
ref_conv_grad = conv.weight.grad.clone()

# Perform forward and backward pass for distributed convolution
conv.zero_grad()
ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
dcx = DCTensor.distribute(x, parallel_strategy)
dcy = ddp_conv(dcx)
ddpy = dcy.to_replicate()
# find gradient wrt input
dc_grads = torch.autograd.grad(
outputs=[ddpy.sum()], inputs=[dcx], create_graph=True
)[0]
dc_grads_rep = dc_grads.to_replicate()
    # the loss is a function of the input gradient only
dc_loss_grad = dc_grads_rep.mean()
dc_loss = dc_loss_grad
dc_loss.backward()
dc_conv_grad = ddp_conv.module.weight.grad

# Validate the results
assert fp32_allclose(ref_loss, dc_loss)
assert fp32_allclose(ref_y, ddpy)
assert fp32_allclose(ref_grads, dc_grads_rep)
assert fp32_allclose(ref_conv_grad, dc_conv_grad)


@pytest.mark.parametrize(*generate_configs())
def test_double_backprop_combinedloss(
parallel_strategy: ParallelStrategy,
ndims: int,
shard_dim: int,
kernel_size: int,
num_shards: int,
device: torch.device,
):
"""
    Test double backpropagation through a distributed convolution where the loss
    combines a term on the output with a term on the gradient with respect to the input.
    Checks the loss, output, first-order input gradient, combined-loss input gradient,
    and weight gradient of the distributed convolution against the reference
    (non-distributed) convolution.

Args:
ndims (int): Number of dimensions for the convolution (1, 2, or 3).
shard_dim (int): Dimension along which the tensor is sharded.
kernel_size (int): Size of the convolution kernel.
        num_shards (int): Number of spatial shards (partitions along the shard dimension).
        device (torch.device): Torch device to run the test on.
"""
# Set the shard dimension for the parallel strategy
parallel_strategy.shard_dim = shard_dim + 2

conv_kwargs = dict(
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=False,
stride=1,
padding_mode="circular",
)

# Initialize the input tensor and convolution layer
shape = [1, 4] + [16] * ndims
x = torch.randn(*shape, device=device, requires_grad=True)
conv_class = getattr(nn, f"Conv{ndims}d")
conv = conv_class(4, 8, **conv_kwargs).to(device).requires_grad_(False)
conv.requires_grad_(True)

# Perform forward and backward pass for reference (non-distributed) convolution
conv.zero_grad()
ref_y = conv(x)
# find gradient wrt input
ref_grads = torch.autograd.grad(
outputs=[ref_y.sum()], inputs=[x], create_graph=True
)[0]
    # combine an output loss term with an input-gradient loss term
ref_loss_y = ref_y.square().norm()
ref_loss_grad = ref_grads.mean()
ref_loss = ref_loss_y + ref_loss_grad
ref_loss.backward()
ref_x_grad = x.grad
ref_conv_grad = conv.weight.grad.clone()

# Perform forward and backward pass for distributed convolution
conv.zero_grad()
ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
dcx = DCTensor.distribute(x, parallel_strategy)
dcy = ddp_conv(dcx)
ddpy = dcy.to_replicate()
# find gradient wrt input
dc_grads = torch.autograd.grad(
outputs=[ddpy.sum()], inputs=[dcx], create_graph=True
)[0]
dc_grads_rep = dc_grads.to_replicate()
    # combine an output loss term with an input-gradient loss term
dc_loss_y = ddpy.square().norm()
dc_loss_grad = dc_grads_rep.mean()
dc_loss = dc_loss_y + dc_loss_grad
dc_loss.backward()
x_grad = dcx.grad.to_replicate()
dc_conv_grad = ddp_conv.module.weight.grad

# Validate the results
assert fp32_allclose(ref_loss, dc_loss)
assert fp32_allclose(ref_y, ddpy)
assert fp32_allclose(ref_grads, dc_grads_rep)
assert fp32_allclose(ref_x_grad, x_grad)
assert fp32_allclose(ref_conv_grad, dc_conv_grad)