diff --git a/distconv/distconv.py b/distconv/distconv.py index 7e9ed6b..27e1b36 100644 --- a/distconv/distconv.py +++ b/distconv/distconv.py @@ -141,6 +141,8 @@ def check_is_distconv_supported( stride: List[int], padding: List[int], dilation: List[int], + transpose: bool, + output_padding: List[int], ) -> None: """ Check if the distributed convolution is supported with the given parameters. @@ -152,31 +154,42 @@ def check_is_distconv_supported( stride (List[int]): The stride of the convolution. padding (List[int]): The padding added to the input tensor. dilation (List[int]): The dilation applied to the kernel. + transpose (bool): Is transposed convolution. + dilation (List[int]): The output padding for transposed convolution. Raises: - Exception: If dilation is not 1. - Exception: If input size is not divisible by stride. - Exception: If kernel size is odd and padding is not equivalent to "same". - Exception: If kernel size is even and padding is not zero. - Exception: If kernel size is even and stride is not divisible by kernel size. + Exception: If local input size is not equal to stride times output size. + Exception: If local output size is not equal to stride times input size for transposed convolution. """ shard_dim = tensor_shard_dim - 2 kernel_size = weight.size(tensor_shard_dim) if dilation[shard_dim] != 1: raise Exception("DistConv: dilation must be 1") - if tensor.size(tensor_shard_dim) % stride[shard_dim] != 0: - raise Exception("DistConv: input size must be divisible by stride") - if kernel_size % 2 == 1: - if (kernel_size // 2) != padding[shard_dim]: + + input_size = tensor.size(tensor_shard_dim) + + if not transpose: + output_size = (input_size + 2 * padding[shard_dim] - kernel_size) // stride[ + shard_dim + ] + 1 + + if output_size * stride[shard_dim] != input_size: raise Exception( - 'DistConv: when kernel size is odd, padding must be equivalent to "same"' + "DistConv: The input size along the shard dimension must equal the stride times the output size for the local tensors.\n" + + "This indicates incompatible kernel size, stride, and/or padding for the given input shape and parallel strategy." ) else: - if padding[shard_dim] != 0: - raise Exception("DistConv: when kernel size is even, padding must be zero") - if stride[shard_dim] % kernel_size != 0: + output_size = ( + (input_size - 1) * stride[shard_dim] + - 2 * padding[shard_dim] + + kernel_size + + output_padding[shard_dim] + ) + + if output_size != input_size * stride[shard_dim]: raise Exception( - "DistConv: when kernel size is even, stride must be divisble by kernel size" + "DistConv: The output size along the shard dimension must equal the stride times the input size for the local tensors.\n" + + "This indicates incompatible kernel size, stride, padding, and/or output padding for the given input shape and parallel strategy." ) @@ -383,7 +396,9 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": args = list(args) # Unpack the necessary arguments - tensor, weight, bias, stride, padding, dilation = args[:6] + tensor, weight, bias, stride, padding, dilation, transpose, output_padding = args[ + :8 + ] # Extract the parallel strategy and shard dimension from the input tensor parallel_strategy = tensor._parallel_strategy @@ -391,10 +406,15 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": is_periodic = tensor._is_periodic for i, shard_dim_i in enumerate(shard_dim): if is_periodic[i]: - assert padding[shard_dim_i - 2] == 0, ( - "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" - ) - padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] + if transpose: + padding[shard_dim_i - 2] -= ( + stride[shard_dim_i - 2] * tensor._periodic_shard_padding[i] + ) + else: + assert padding[shard_dim_i - 2] == 0, ( + "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + ) + padding[shard_dim_i - 2] = tensor._periodic_shard_padding[i] # Unwrap the underlying tensor from the DCTensor torch_tensor = tensor._tensor @@ -404,7 +424,14 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": halo_sizes = [] for i, shard_dim_i in enumerate(shard_dim): check_is_distconv_supported( - shard_dim_i, torch_tensor, weight, stride, padding, dilation + shard_dim_i, + torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, ) # Determine the halo size for halo exchange @@ -420,10 +447,14 @@ def distconv_forward(func: Callable, args: Tuple, kwargs: Dict) -> "DCTensor": # Save the tensor with its halo for the backward pass. tensor._tensor_with_halo = tensor_with_halo + if transpose: + padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size + else: + padding[shard_dim_i - 2] = 0 # Update the arguments with the tensor including halos and adjusted padding args[0] = tensor_with_halo - padding[shard_dim_i - 2] = 0 args[4] = padding + args[7] = output_padding tensor._tensor = tensor_with_halo for i, shard_dim_i in enumerate(shard_dim): @@ -456,9 +487,17 @@ def distconv_backward( args = list(args) # Unpack the necessary arguments - grad_out_tensor, input_tensor, weight, bias_size, stride, padding, dilation = args[ - :7 - ] + ( + grad_out_tensor, + input_tensor, + weight, + bias_size, + stride, + padding, + dilation, + transpose, + output_padding, + ) = args[:9] # Extract the parallel strategy and shard dimension from the gradient output tensor parallel_strategy = grad_out_tensor._parallel_strategy @@ -466,10 +505,15 @@ def distconv_backward( is_periodic = input_tensor._is_periodic for i, shard_dim_i in enumerate(shard_dim): if is_periodic[i]: - assert padding[shard_dim_i - 2] == 0, ( - "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" - ) - padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] + if transpose: + padding[shard_dim_i - 2] -= ( + stride[shard_dim_i - 2] * input_tensor._periodic_shard_padding[i] + ) + else: + assert padding[shard_dim_i - 2] == 0, ( + "Cannot zero-pad a tensor marked for periodic padding on the shard dimension" + ) + padding[shard_dim_i - 2] = input_tensor._periodic_shard_padding[i] # Unwrap the underlying tensors from the DCTensors grad_out_tensor = grad_out_tensor._tensor @@ -478,15 +522,24 @@ def distconv_backward( # Check if the distributed convolution is supported with the given parameters halo_sizes = [] for i, shard_dim_i in enumerate(shard_dim): - check_is_distconv_supported( - shard_dim_i, input_torch_tensor, weight, stride, padding, dilation - ) - # Determine the halo size for halo exchange kernel_size = weight.size(shard_dim_i) halo_size = kernel_size // 2 if (kernel_size % 2 == 1) else 0 halo_sizes.append(halo_size) - padding[shard_dim_i - 2] = 0 + check_is_distconv_supported( + shard_dim_i, + input_torch_tensor, + weight, + stride, + padding, + dilation, + transpose, + output_padding, + ) + if transpose: + padding[shard_dim_i - 2] += stride[shard_dim_i - 2] * halo_size + else: + padding[shard_dim_i - 2] = 0 # Get the input tensor including halos if available, otherwise perform forward halo exchange if input_tensor._tensor_with_halo is not None: @@ -506,6 +559,7 @@ def distconv_backward( args[0] = grad_out_tensor args[1] = input_tensor_with_halo args[5] = padding + args[8] = output_padding # Perform the backward convolution operation grad_in_tensor, grad_weight, grad_bias = func(*args, **kwargs) diff --git a/tests/test_convtranspose.py b/tests/test_convtranspose.py new file mode 100644 index 0000000..202a10d --- /dev/null +++ b/tests/test_convtranspose.py @@ -0,0 +1,170 @@ +import pytest +import torch +import torch.nn as nn +from utils import cleanup_parallel_strategy, fp32_allclose + +from distconv import DCTensor, DistConvDDP, ParallelStrategy + + +@pytest.fixture(scope="module") +def parallel_strategy(device: torch.device): + ps = ParallelStrategy(num_shards=4, device_type=device.type) + yield ps + cleanup_parallel_strategy(ps) + + +def find_padding(kernel_size, stride=1, explicit_padding=False): + ep = kernel_size // 2 if explicit_padding else 0 + pad = (kernel_size + 2 * ep * stride - 1) // 2 + out_pad = stride - 1 + if explicit_padding: + return pad, out_pad, ep + return pad, out_pad + + +def generate_configs(): + configs = [] + for ndims in [1, 2, 3]: + for shard_dim in range(ndims): + for kernel_size in [1, 3, 5]: + for stride in [1, 2, 4]: + configs.append((ndims, shard_dim, kernel_size, stride)) + + return "ndims,shard_dim,kernel_size,stride", configs + + +@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_zerospadding( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. + Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding, output_padding = find_padding(kernel_size, stride) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = conv_class( + 4, + 8, + kernel_size=kernel_size, + padding=padding, + stride=stride, + output_padding=output_padding, + ).to(device) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x) + ref_y.square().mean().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcy = dist_conv(dcx) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) + + +@pytest.mark.parametrize(*generate_configs()) +def test_transposeconv_circularpadding( + parallel_strategy: ParallelStrategy, + ndims: int, + shard_dim: int, + kernel_size: int, + stride: int, + device: torch.device, +): + """ + Test distributed convolution with different number of dimensions, kernel sizes, and strides. + Checks the output and gradients of the distributed convolution against the non-distributed + convolution. + + Args: + parallel_strategy (ParallelStrategy): Parallel strategy for the distributed convolution. + ndims (int): Number of dimensions for the convolution (1, 2, or 3). + shard_dim (int): Dimension along which the tensor is sharded. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution. + device (torch.device): Torch device to run test with. + """ + # Set the shard dimension for the parallel strategy + parallel_strategy.shard_dim = 2 + shard_dim + padding, output_padding, explicit_padding = find_padding( + kernel_size, stride, explicit_padding=True + ) + + # Initialize the input tensor and convolution layer + shape = [1, 4] + [64] * ndims + x = torch.randn(*shape, device=device, requires_grad=True) + + conv_kwargs = dict( + kernel_size=kernel_size, stride=stride, output_padding=output_padding + ) + + # set periodic padding case for reference + explicit_padding = [explicit_padding, explicit_padding] * ndims + x_periodic = torch.nn.functional.pad(input=x, pad=explicit_padding, mode="circular") + + conv_class = getattr(nn, f"ConvTranspose{ndims}d") + conv = ( + conv_class(4, 8, padding=padding, **conv_kwargs) + .to(device) + .requires_grad_(False) + ) + conv.requires_grad_(True) + + # Perform forward and backward pass for reference (non-distributed) convolution + conv.zero_grad() + ref_y = conv(x_periodic) + ref_y.square().mean().backward() + ref_x_grad = x.grad + ref_conv_grad = conv.weight.grad + + # Perform forward and backward pass for distributed convolution + conv.zero_grad() + dist_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy) + dcx = DCTensor.distribute(x, parallel_strategy) + dcx_periodic = torch.nn.functional.pad( + input=dcx, pad=explicit_padding, mode="circular" + ) + dcy = dist_conv(dcx_periodic) + dcy_merge = dcy.to_replicate() + dc_loss = dcy.to_ddp().square().mean() + dc_loss.backward() + x_grad = dcx.grad.to_replicate() + dc_conv_grad = conv.weight.grad + + # Validate the results + assert fp32_allclose(ref_y, dcy_merge) + assert fp32_allclose(ref_x_grad, x_grad) + assert fp32_allclose(ref_conv_grad, dc_conv_grad) diff --git a/tests/test_periodic.py b/tests/test_periodic.py index af6e309..9e13a88 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -1,3 +1,5 @@ +from math import ceil + import pytest import torch import torch.distributed as dist @@ -50,7 +52,7 @@ def test_periodic( conv_kwargs = dict( kernel_size=kernel_size, - padding=kernel_size // 2, + padding=ceil((kernel_size - stride) / 2), bias=False, stride=stride, padding_mode="circular",