
AsyncCheckpointIO should clone() to CPU, not on GPU #21630

@TheGreatFrankie

Description

Bug description

In pytorch-lightning/src/lightning/pytorch/plugins/io/async_plugin.py, tensors in the model checkpoint are cloned to prevent race conditions while the checkpoint is uploaded asynchronously in threads:

import torch

# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    """Clones a tensor on the caller thread."""
    # detach to avoid autograd history and clone to take a point-in-time copy
    return t.detach().clone()

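However, clone() here copies from GPU memory to GPU memory, so each snapshot temporarily doubles the GPU footprint of the checkpointed tensors. Because GPU memory is often limited, this step is dangerous, and more so as the number of threads (num_thread) goes up.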

A cleaner solution is to clone the tensors to CPU. This achieves the same purpose without using GPU memory, and CPU memory is abundant most of the time.

import torch

# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    """Clones a tensor to CPU on the caller thread."""
    # detach to avoid autograd history, then copy to CPU so no GPU memory is used;
    # .cpu() returns the same tensor if it is already on CPU, so .clone() keeps
    # the snapshot independent in that case
    return t.detach().cpu().clone()
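A minimal sketch of the proposed pattern applied to a whole checkpoint payload, assuming the payload is made of nested dicts, lists, and tuples of tensors: every tensor is snapshotted to CPU on the caller thread, and only that snapshot is handed to a worker thread for serialization. `snapshot_to_cpu` and `save_async` are hypothetical helpers for illustration, not Lightning's API, and an in-memory `io.BytesIO` buffer stands in for the real upload.

```python
import io
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any

import torch


def snapshot_to_cpu(obj: Any) -> Any:
    """Hypothetical helper: recursively copy every tensor in a payload to CPU."""
    if isinstance(obj, torch.Tensor):
        # copy=True forces a new tensor even if `obj` is already on CPU; for a CUDA
        # tensor this is a single device-to-host copy and allocates no GPU memory.
        return obj.detach().to("cpu", copy=True)
    if isinstance(obj, dict):
        # dict subclasses (e.g. OrderedDict) are rebuilt as plain dicts in this sketch
        return {k: snapshot_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [snapshot_to_cpu(v) for v in obj]
        # namedtuples are rebuilt as plain tuples in this sketch
        return items if isinstance(obj, list) else tuple(items)
    return obj  # non-tensor leaves (ints, strings, ...) are kept as-is


def save_async(executor: ThreadPoolExecutor, checkpoint: dict) -> Future:
    """Hypothetical helper: snapshot on the caller thread, serialize on a worker."""
    # The snapshot is complete before this function returns, so later in-place
    # updates to the live parameters cannot leak into the saved checkpoint.
    snapshot = snapshot_to_cpu(checkpoint)

    def _write() -> bytes:
        buffer = io.BytesIO()  # stands in for a file or remote upload
        torch.save(snapshot, buffer)
        return buffer.getvalue()

    return executor.submit(_write)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    expected = model.weight.detach().clone()  # weights before the "optimizer step"
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = save_async(executor, {"state_dict": model.state_dict(), "epoch": 3})
        with torch.no_grad():
            model.weight.add_(1.0)  # simulate a parameter update racing with the save
        restored = torch.load(io.BytesIO(future.result()), weights_only=True)
    # The saved weights are the pre-update values, not the mutated ones.
    print(torch.equal(restored["state_dict"]["weight"], expected))
```

Compared with `t.detach().cpu().clone()`, `.to("cpu", copy=True)` skips the extra host-to-host copy for CUDA tensors (`.cpu()` already returns a new CPU tensor there) while still producing an independent copy for tensors that are already on CPU. Either way, the trade-off is host RAM plus a synchronous device-to-host transfer on the caller thread.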

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

Cloning in GPU memory is dangerous for large model checkpoints (e.g. 15 GB in our case).

Cloning to CPU

[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB
[ASYNC CHECKPOINT AFTER clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB

Cloning on GPU

[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB
[ASYNC CHECKPOINT AFTER clone] GPU 0: allocated=37.54 GB, reserved=124.30 GB
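To reproduce log lines in the shape shown above, a small helper can read PyTorch's allocator counters (`torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`) around the clone. This is a sketch: it assumes "GB" in the logs means GiB (1024**3 bytes), which the issue does not state, and `format_report`, `gpu_memory_report`, and `compare_clone_strategies` are hypothetical names.

```python
from typing import Optional

import torch

GIB = 1024**3  # assumption: "GB" in the logs means GiB


def format_report(tag: str, device: int, allocated: int, reserved: int) -> str:
    """Format allocator counters in the same shape as the log lines above."""
    return (
        f"[{tag}] GPU {device}: "
        f"allocated={allocated / GIB:.2f} GB, reserved={reserved / GIB:.2f} GB"
    )


def gpu_memory_report(tag: str, device: int = 0) -> Optional[str]:
    """Return a report for one CUDA device, or None when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    return format_report(
        tag,
        device,
        torch.cuda.memory_allocated(device),
        torch.cuda.memory_reserved(device),
    )


def compare_clone_strategies(numel: int = 256 * 1024 * 1024) -> None:
    """Clone a float32 tensor (1 GiB by default) on GPU, then to CPU, logging memory."""
    if not torch.cuda.is_available():
        print("CUDA not available; skipping")
        return
    t = torch.zeros(numel, device="cuda")
    print(gpu_memory_report("BEFORE clone"))
    gpu_copy = t.detach().clone()  # GPU -> GPU: "allocated" grows by about the tensor size
    print(gpu_memory_report("AFTER GPU clone"))
    del gpu_copy  # freed blocks stay cached: "allocated" drops, "reserved" does not
    cpu_copy = t.detach().cpu().clone()  # GPU -> CPU: no new GPU allocation
    print(gpu_memory_report("AFTER CPU clone"))
    del cpu_copy
```

With the GPU clone, the allocated counter should grow by roughly the size of the copied tensors, matching the jump from 21.54 GB to 37.54 GB reported above; with the CPU copy it should stay flat.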

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.6.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @ethanwharris

Metadata

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.7.x
