Skip to content

Periodic padding memory leak #16

@jvwilliams23

Description

@jvwilliams23

The current implementation of _handle_circular_pad causes a memory leak: with padding_mode="circular", peak GPU memory grows on every forward/backward iteration, while with "zeros" it stays flat.

For example, the following script:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from distconv import ParallelStrategy, DCTensor, DistConvDDP

def main(parallel_strategy, padding_mode="zeros"):
    """Run repeated forward/backward passes of a distributed convolution and
    print peak GPU memory per iteration.

    Reads the module-level globals ``kernel_size``, ``ndims``, ``device`` and
    ``local_rank`` set under the ``__main__`` guard.

    Args:
        parallel_strategy: distconv ParallelStrategy used to shard tensors.
        padding_mode: padding mode forwarded to the Conv layer
            (e.g. "zeros" or "circular").
    """
    # Build the N-dimensional convolution with "same" padding and wrap it
    # for distributed execution.
    conv = getattr(nn, f"Conv{ndims}d")(
        4,
        8,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        bias=False,
        stride=1,
        padding_mode=padding_mode,
    ).to(device).requires_grad_(False)
    torch.nn.init.ones_(conv.weight)
    conv.requires_grad_(True)
    ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)

    # One random input tensor, reused unchanged across all iterations.
    input_shape = [1, 4] + [4000] * ndims
    inp = torch.randn(*input_shape, device=device, requires_grad=True)

    # Repeated forward + backward; with a leak-free implementation the
    # reported peak memory should be identical on every line.
    for _ in range(10):
        conv.zero_grad()
        sharded = DCTensor.distribute(inp, parallel_strategy)
        out = ddp_conv(sharded)
        out.to_ddp().square().mean().backward()
        if local_rank == 0:
            print(f"{padding_mode} gpumem {(torch.cuda.max_memory_allocated(device) / 2**30):<6.4f}")


if __name__ == "__main__":
    # Read the rank under the main guard: the original module-level read made
    # importing this file crash whenever LOCAL_RANK (set by torchrun) was
    # absent, even though the value is only needed when run as a script.
    local_rank = int(os.environ["LOCAL_RANK"])

    # Experiment configuration (read as globals by main()).
    ndims = 2
    device = torch.device(local_rank)
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    kernel_size = 3
    parallel_strategy = ParallelStrategy(num_shards=2, device_type=device.type)

    try:
        for padding_mode in ["zeros", "circular"]:
            main(parallel_strategy, padding_mode)
            if local_rank == 0:
                print(f"final {padding_mode} gpumem {(torch.cuda.max_memory_allocated(device) / 2**30):<6.4f}")
    finally:
        # Always tear down the process group, even if a pass fails, so ranks
        # do not hang waiting on a dead peer.
        dist.destroy_process_group()

gives the following output on my machine:

zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
final zeros gpumem 2.9804
circular gpumem 3.0996
circular gpumem 3.3380
circular gpumem 3.5765
circular gpumem 3.8149
circular gpumem 4.0533
circular gpumem 4.2917
circular gpumem 4.5301
circular gpumem 4.7686
circular gpumem 5.0070
circular gpumem 5.2454
final circular gpumem 5.2454

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions