AsyncCheckpointIO Should Clone() to CPU not on GPU #21630
Open
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.7.x
Description
Bug description
In pytorch-lightning/src/lightning/pytorch/plugins/io/async_plugin.py
Tensors in the model checkpoint are cloned to prevent race conditions during asynchronous uploading in background threads:
```python
# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    """Clones a tensor on the caller thread."""
    # detach to avoid autograd history and clone to take a point-in-time copy
    return t.detach().clone()
```
However, this clones from GPU memory to GPU memory. Since GPU memory is often limited, the extra allocation becomes dangerous as the number of checkpointing threads grows.
A cleaner solution is to clone the tensors to CPU. This achieves the same purpose without consuming GPU memory, and CPU memory is usually abundant:
```python
# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    """Clones a tensor to CPU on the caller thread."""
    # detach to avoid autograd history; .cpu() copies the data off the GPU,
    # so the point-in-time snapshot lives in host memory
    return t.detach().cpu().clone()
```
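The underlying pattern (snapshot on the caller thread, then hand the copy to a background thread) can be sketched without torch using only the standard library. `save_async` and the dict payload here are illustrative stand-ins, not Lightning's API:

```python
import copy
import threading

def snapshot(payload: dict) -> dict:
    # take a point-in-time deep copy on the caller thread,
    # so later mutations of the live payload cannot race the writer
    return copy.deepcopy(payload)

def save_async(payload: dict, results: list) -> threading.Thread:
    snap = snapshot(payload)  # copy BEFORE spawning the thread
    t = threading.Thread(target=lambda: results.append(snap))
    t.start()
    return t

state = {"step": 1, "weights": [0.1, 0.2]}
results = []
t = save_async(state, results)
state["weights"][0] = 9.9  # mutate after handoff; the snapshot is unaffected
t.join()
```

Because the copy happens before the thread starts, the "checkpoint" in `results` still holds the pre-mutation weights; the fix discussed in this issue is about *where* that copy lives (CPU vs GPU), not whether to copy.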
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
Cloning in GPU memory is dangerous for large model checkpoints (about 15 GB in our case):
Cloning to CPU:

```
[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB
[ASYNC CHECKPOINT AFTER clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB
```
Cloning on GPU (allocated memory grows by 16 GB, roughly the size of the checkpoint):

```
[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB, reserved=124.30 GB
[ASYNC CHECKPOINT AFTER clone] GPU 0: allocated=37.54 GB, reserved=124.30 GB
```
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.6.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response