fix: clone async checkpoint tensors to CPU to prevent GPU OOM #21631
karthik-idikuda wants to merge 4 commits into Lightning-AI:master from
Conversation
karthik-idikuda force-pushed from 71adffe to 95d6ff2
Move tensor cloning to CPU in `AsyncCheckpointIO._clone_tensor()` to prevent doubling GPU memory usage during async checkpoint saves.

Previously, `_clone_tensor()` called `t.detach().clone()`, which allocates new GPU memory for each cloned tensor. For large model checkpoints (e.g., 15 GB+), this can cause GPU OOM errors, since the entire checkpoint is temporarily duplicated in GPU memory.

The fix changes the operation to `t.detach().cpu().clone()`, which moves tensors to CPU before cloning. CPU memory is typically abundant, and this achieves the same race-condition prevention without the GPU memory overhead.

Fixes Lightning-AI#21630
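A minimal sketch of the patched helper (the helper name comes from the PR description; the surrounding `AsyncCheckpointIO` plumbing is omitted):

```python
import torch

def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    # Detach from the autograd graph, move to host memory, then clone.
    # Moving to CPU first means the snapshot never allocates a second
    # copy on the GPU, so a checkpoint no longer doubles GPU usage.
    return t.detach().cpu().clone()

# Even for a CPU tensor this yields an independent copy, which is what
# protects the async writer from later in-place parameter updates.
params = torch.zeros(4)
snapshot = _clone_tensor(params)
params.add_(1.0)  # simulate the optimizer stepping after the snapshot
```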
karthik-idikuda force-pushed from 95d6ff2 to 865d600
Codecov Report: ❌ Patch coverage is

Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #21631      +/-   ##
=========================================
- Coverage      87%      79%       -8%
=========================================
  Files         270      267        -3
  Lines       23898    23877       -21
=========================================
- Hits        20678    18799     -1879
- Misses       3220     5078     +1858
For CUDA tensors, `cpu()` already allocates a new host-memory copy, so an additional `clone()` is unnecessary and wastes memory bandwidth. For CPU tensors, `cpu()` is a no-op, so `clone()` remains necessary. Co-authored-by: TheGreatFrankie
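The CPU no-op behavior described above is easy to check (a small sketch; the CUDA case needs a GPU and is only described in the comments):

```python
import torch

t = torch.zeros(3)

# For CPU tensors, .cpu() returns the very same object (no copy is
# made), so .clone() is what actually gives the snapshot its own storage.
assert t.cpu() is t

snap = t.detach().cpu().clone()
t.add_(1.0)  # in-place mutation after the snapshot was taken

# The clone is unaffected by the mutation. For CUDA tensors, by
# contrast, .cpu() already allocates fresh host memory, so tacking on
# .clone() copies the data a second time.
```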
I am not sure we want this change. The reason is that a transfer from GPU to CPU is a synchronization point, and the whole point of async checkpointing is to avoid those. If in doubt, you can still use synchronous checkpointing.
🤔 but the original in-memory GPU-to-GPU clone is also a synchronization point. I don't fully understand your point.
Summary
Clone async checkpoint tensors to CPU to prevent GPU OOM during async saves.
Details
`AsyncCheckpointIO._clone_tensor()` previously called `t.detach().clone()`, which allocates new GPU memory for each cloned tensor. For large model checkpoints (e.g., 15 GB+), this doubles GPU memory usage during checkpoint saves, which can cause OOM errors.

Before (GPU clone): `t.detach().clone()`

After (CPU clone): `t.detach().cpu().clone()`
Changes
`async_plugin.py`: Changed `t.detach().clone()` → `t.detach().cpu().clone()` in `_clone_tensor()`. Moving to CPU first avoids doubling GPU memory. CPU memory is typically abundant, and this still prevents the race condition with parameter mutation.

`test_async_checkpoint.py`: Added `test_async_checkpoint_clones_tensors_to_cpu()` to verify cloned tensors are on CPU and retain correct values.

Fixes #21630
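The added test itself is not shown in this excerpt; a hypothetical re-creation of what it likely checks (the real test in `test_async_checkpoint.py` presumably exercises `AsyncCheckpointIO` end to end, while this sketch checks the clone helper directly):

```python
import torch

def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for AsyncCheckpointIO._clone_tensor after the fix.
    return t.detach().cpu().clone()

def test_async_checkpoint_clones_tensors_to_cpu():
    src = torch.arange(4, dtype=torch.float32)
    cloned = _clone_tensor(src)
    assert cloned.device.type == "cpu"   # snapshot lives on CPU
    assert torch.equal(cloned, src)      # values are preserved
    src.mul_(2.0)                        # mutate after the snapshot
    assert cloned[1].item() == 1.0       # snapshot is unaffected

test_async_checkpoint_clones_tensors_to_cpu()
```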
📚 Documentation preview 📚: https://pytorch-lightning--21631.org.readthedocs.build/en/21631/