-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
The current implementation of _handle_circular_pad causes a memory leak: GPU memory grows on every forward/backward iteration when padding_mode="circular" is used.
For example, the following script:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from distconv import ParallelStrategy, DCTensor, DistConvDDP
def main(parallel_strategy, padding_mode="zeros"):
    """Run 10 forward/backward passes through a distributed convolution and
    print peak GPU memory each iteration.

    With padding_mode="zeros" the peak plateaus; with "circular" it grows
    every iteration, demonstrating the leak in _handle_circular_pad.

    NOTE(review): relies on module-level globals — kernel_size, ndims,
    device, and local_rank — being defined before this is called.
    """
    conv_kwargs = {
        "kernel_size": kernel_size,
        "padding": kernel_size // 2,
        "bias": False,
        "stride": 1,
        "padding_mode": padding_mode,
    }
    # Build the input tensor shape and the matching Conv1d/2d/3d layer.
    shape = [1, 4] + [4000] * ndims
    conv_class = getattr(nn, f"Conv{ndims}d")
    conv = conv_class(4, 8, **conv_kwargs).to(device).requires_grad_(False)
    torch.nn.init.ones_(conv.weight)
    conv.requires_grad_(True)
    ddp_conv = DistConvDDP(conv, parallel_strategy=parallel_strategy)
    # One shared input tensor; it is re-distributed on every iteration.
    x = torch.randn(*shape, device=device, requires_grad=True)
    for _ in range(10):
        conv.zero_grad()
        dcx = DCTensor.distribute(x, parallel_strategy)
        dcy = ddp_conv(dcx)
        dcy.to_ddp().square().mean().backward()
        if local_rank == 0:
            # Peak allocated memory so far — should stay flat if nothing leaks.
            print(f"{padding_mode} gpumem {(torch.cuda.max_memory_allocated(device) / 2**30):<6.4f}")
# torchrun sets LOCAL_RANK; it selects the CUDA device for this process.
local_rank = int(os.environ["LOCAL_RANK"])

if __name__ == "__main__":
    ndims = 2
    device = torch.device(local_rank)
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    kernel_size = 3
    parallel_strategy = ParallelStrategy(num_shards=2, device_type=device.type)
    # Run the leak-free baseline first, then the leaking circular mode.
    for padding_mode in ["zeros", "circular"]:
        main(parallel_strategy, padding_mode)
        if local_rank == 0:
            print(f"final {padding_mode} gpumem {(torch.cuda.max_memory_allocated(device) / 2**30):<6.4f}")
    dist.destroy_process_group()

gives the following output on my machine:
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
zeros gpumem 2.9804
final zeros gpumem 2.9804
circular gpumem 3.0996
circular gpumem 3.3380
circular gpumem 3.5765
circular gpumem 3.8149
circular gpumem 4.0533
circular gpumem 4.2917
circular gpumem 4.5301
circular gpumem 4.7686
circular gpumem 5.0070
circular gpumem 5.2454
final circular gpumem 5.2454
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels