121 changes: 87 additions & 34 deletions distconv/distconv.py
@@ -49,6 +49,8 @@ def check_is_distconv_supported(
stride: List[int],
padding: List[int],
dilation: List[int],
transpose: bool,
output_padding: List[int],
) -> None:
"""
Check if the distributed convolution is supported with the given parameters.
@@ -60,31 +62,42 @@ def check_is_distconv_supported(
stride (List[int]): The stride of the convolution.
padding (List[int]): The padding added to the input tensor.
dilation (List[int]): The dilation applied to the kernel.
transpose (bool): Whether the convolution is a transposed convolution.
output_padding (List[int]): The output padding for the transposed convolution.

Raises:
Exception: If dilation is not 1.
Exception: If input size is not divisible by stride.
Exception: If kernel size is odd and padding is not equivalent to "same".
Exception: If kernel size is even and padding is not zero.
Exception: If kernel size is even and stride is not divisible by kernel size.
Exception: If local input size is not equal to stride times output size.
Exception: If local output size is not equal to stride times input size for transposed convolution.
"""
shard_dim = tensor_shard_dim - 2
kernel_size = weight.size(tensor_shard_dim)
if dilation[shard_dim] != 1:
raise Exception("DistConv: dilation must be 1")
if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0:
raise Exception("DistConv: input size must be divisible by stride")
if kernel_size % 2 == 1:
if (kernel_size // 2) != padding[shard_dim]:

input_size = tensor.size(tensor_shard_dim)

if not transpose:
output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[
shard_dim
] + 1

if output_size * stride[shard_dim] != input_size:
raise Exception(
'DistConv: when kernel size is odd, padding must be equivalent to "same"'
"DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n"
+ "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy."
)
else:
if padding[shard_dim] != 0:
raise Exception("DistConv: when kernel size is even, padding must be zero")
if stride[shard_dim] % kernel_size != 0:
output_size = (
(input_size - 1) * stride[shard_dim]
- 2 * padding[shard_dim]
+ kernel_size
+ output_padding[shard_dim]
)

if output_size != input_size * stride[shard_dim]:
raise Exception(
"DistConv: when kernel size is even, stride must be divisble by kernel size"
"DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n"
+ "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy."
)
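
For intuition, here is a small standalone sketch of the size constraints these checks enforce; the concrete kernel, stride, and padding values below are illustrative choices, not taken from this PR:

```python
# Local (per-shard) sizes along the shard dimension, illustrative values only.

# Regular convolution: kernel 3, stride 1, padding 1 ("same"-style) on a local shard of 64.
in_sz, k, s, p = 64, 3, 1, 1
out_sz = (in_sz + 2 * p - k) // s + 1
assert out_sz * s == in_sz  # 64 * 1 == 64, so check_is_distconv_supported would pass

# Transposed convolution: kernel 4, stride 2, padding 1, output_padding 0.
in_sz, k, s, p, op = 64, 4, 2, 1, 0
out_sz = (in_sz - 1) * s - 2 * p + k + op
assert out_sz == in_sz * s  # 128 == 64 * 2, so the transposed-case check would pass
```

Both constraints tie the local output length to the local input length through the stride, which is what lets the per-rank outputs be concatenated cleanly along the shard dimension.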


@@ -240,29 +253,42 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
args = list(args)

# Unpack the necessary arguments
tensor, weight, bias, stride, padding, dilation = args[:6]
tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[
:8
]

# Extract the parallel strategy and shard dimension from the input tensor
parallel_strategy = tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
is_periodic = tensor._is_periodic
if is_periodic:
assert padding[shard_dim - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim - 2] = tensor._periodic_shard_padding
if transpose:
padding[shard_dim - 2] -= (
stride[shard_dim - 2] * tensor._periodic_shard_padding
)
else:
assert padding[shard_dim - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim - 2] = tensor._periodic_shard_padding

# Unwrap the underlying tensor from the DCTensor
torch_tensor = tensor._tensor

# Check if the distributed convolution is supported with the given parameters
check_is_distconv_supported(
shard_dim, torch_tensor, weight, stride, padding, dilation
shard_dim,
torch_tensor,
weight,
stride,
padding,
dilation,
transpose,
output_padding,
)

# Determine the halo size for halo exchange
kernel_size = weight.size(shard_dim)
halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
halo_size = padding[shard_dim - 2]

# Perform forward halo exchange to prepare the tensor for convolution
tensor_with_halo = forward_halo_exchange(
@@ -276,9 +302,13 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor":
)

# Update the arguments with the tensor including halos and adjusted padding
padding[shard_dim - 2] = 0
if transpose:
padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size
else:
padding[shard_dim - 2] = 0
args[0] = tensor_with_halo
args[4] = padding
args[7] = output_padding

# Perform the convolution operation
out_tensor = func(*args, **kwargs)
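
As a sanity check on the padding adjustment above, the following standalone PyTorch sketch simulates a two-way shard with a manual halo exchange (it does not use forward_halo_exchange or DCTensor, and the shard sizes and zero boundary fill are assumptions made for illustration) and confirms that convolving the halo-extended shards with padding 0 reproduces the unsharded result:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 8)   # full (unsharded) 1-D input: batch 1, 1 channel, length 8
w = torch.randn(1, 1, 3)   # kernel size 3, so each shard needs a halo of 1

ref = F.conv1d(x, w, padding=1)  # reference: ordinary "same"-padded convolution

# Two shards of length 4; each receives one halo column from its neighbor
# and a zero column at the global boundary (standing in for the zero padding).
z = torch.zeros(1, 1, 1)
left = torch.cat([z, x[..., :4], x[..., 4:5]], dim=-1)
right = torch.cat([x[..., 3:4], x[..., 4:], z], dim=-1)

# After the halo exchange, the local convolutions run with padding = 0.
out = torch.cat([F.conv1d(left, w, padding=0), F.conv1d(right, w, padding=0)], dim=-1)
assert torch.allclose(out, ref, atol=1e-6)
```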
@@ -305,32 +335,51 @@ def distconv_backward(
args = list(args)

# Unpack the necessary arguments
grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[
:7
]
(
grad_out_tensor,
input_tensor,
weight,
bias_size,
stride,
padding,
dilation,
transpose,
output_padding,
) = args[:9]

# Extract the parallel strategy and shard dimension from the gradient output tensor
parallel_strategy = grad_out_tensor._parallel_strategy
shard_dim = parallel_strategy.shard_dim
is_periodic = input_tensor._is_periodic
if is_periodic:
assert padding[shard_dim - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim - 2] = input_tensor._periodic_shard_padding
if transpose:
padding[shard_dim - 2] -= (
stride[shard_dim - 2] * input_tensor._periodic_shard_padding
)
else:
assert padding[shard_dim - 2] == 0, (
"Cannot zero-pad a tensor marked for periodic padding on the shard dimension"
)
padding[shard_dim - 2] = input_tensor._periodic_shard_padding

# Unwrap the underlying tensors from the DCTensors
grad_out_tensor = grad_out_tensor._tensor
input_torch_tensor = input_tensor._tensor

# Check if the distributed convolution is supported with the given parameters
check_is_distconv_supported(
shard_dim, input_torch_tensor, weight, stride, padding, dilation
shard_dim,
input_torch_tensor,
weight,
stride,
padding,
dilation,
transpose,
output_padding,
)

# Determine the halo size for halo exchange
kernel_size = weight.size(shard_dim)
halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0
halo_size = padding[shard_dim - 2]

# Get the input tensor including halos if available, otherwise perform forward halo exchange
if input_tensor._tensor_with_halo is not None:
@@ -341,10 +390,14 @@ def distconv_backward(
)

# Update the arguments with the gradient output tensor, input tensor including halos, and adjusted padding
padding[shard_dim - 2] = 0
if transpose:
padding[shard_dim - 2] += stride[shard_dim - 2] * halo_size
else:
padding[shard_dim - 2] = 0
args[0] = grad_out_tensor
args[1] = input_tensor_with_halo
args[5] = padding
args[8] = output_padding

# Perform the backward convolution operation
grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs)
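
In the transposed branch, the local padding instead grows by stride * halo_size, which in this construction crops the output columns that belong to the neighboring rank. A standalone sketch under the same assumptions as the forward example (two shards, manual halo exchange, zero boundary fill, example sizes chosen for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 8)   # full 1-D input, split into two shards of length 4
w = torch.randn(1, 1, 4)   # kernel 4, stride 2, padding 1, output_padding 0

ref = F.conv_transpose1d(x, w, stride=2, padding=1)  # global output, length 16

halo = 1                   # halo size equals the original padding
pad_local = 1 + 2 * halo   # padding + stride * halo_size, mirroring the adjustment above
z = torch.zeros(1, 1, 1)
left = torch.cat([z, x[..., :4], x[..., 4:5]], dim=-1)
right = torch.cat([x[..., 3:4], x[..., 4:], z], dim=-1)

out = torch.cat(
    [F.conv_transpose1d(t, w, stride=2, padding=pad_local) for t in (left, right)],
    dim=-1,
)
assert out.shape == ref.shape          # each shard contributes 8 = 4 * stride columns
assert torch.allclose(out, ref, atol=1e-5)
```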
4 changes: 2 additions & 2 deletions tests/test_basic.py
@@ -2,10 +2,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from utils import cleanup_parallel_strategy, fp32_allclose

from distconv import DCTensor, DistConvDDP, ParallelStrategy

from utils import cleanup_parallel_strategy, fp32_allclose


@pytest.fixture(scope="module")
def parallel_strategy(device: torch.device):