
fix: clone async checkpoint tensors to CPU to prevent GPU OOM #21631

Open
karthik-idikuda wants to merge 4 commits into Lightning-AI:master from karthik-idikuda:fix/async-checkpoint-clone-to-cpu

Conversation

karthik-idikuda commented Mar 31, 2026

Summary

Clone async checkpoint tensors to CPU to prevent GPU OOM during async saves.

Details

AsyncCheckpointIO._clone_tensor() previously called t.detach().clone(), which allocates new GPU memory for each cloned tensor. For large model checkpoints (e.g., 15GB+), this doubles GPU memory usage during checkpoint saves, which can cause OOM errors.

Before (GPU clone):

[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB
[ASYNC CHECKPOINT AFTER clone]  GPU 0: allocated=37.54 GB  ← +16 GB!

After (CPU clone):

[ASYNC CHECKPOINT BEFORE clone] GPU 0: allocated=21.54 GB
[ASYNC CHECKPOINT AFTER clone]  GPU 0: allocated=21.54 GB  ← no change

Changes

  • async_plugin.py: Changed t.detach().clone() to t.detach().cpu().clone() in _clone_tensor(). Moving the tensor to CPU before cloning avoids doubling GPU memory; CPU memory is typically abundant, and the clone still prevents the race condition with in-place parameter mutation.
  • test_async_checkpoint.py: Added test_async_checkpoint_clones_tensors_to_cpu() to verify cloned tensors are on CPU and retain correct values.

Fixes #21630


📚 Documentation preview 📚: https://pytorch-lightning--21631.org.readthedocs.build/en/21631/

github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Mar 31, 2026
karthik-idikuda force-pushed the fix/async-checkpoint-clone-to-cpu branch from 71adffe to 95d6ff2 on March 31, 2026 10:33
Move tensor cloning to CPU in AsyncCheckpointIO._clone_tensor() to prevent
doubling GPU memory usage during async checkpoint saves.

Previously, _clone_tensor() called t.detach().clone() which allocates new GPU
memory for each cloned tensor. For large model checkpoints (e.g., 15GB+), this
can cause GPU OOM errors since the entire checkpoint is temporarily duplicated
in GPU memory.

The fix changes the operation to t.detach().cpu().clone(), which moves tensors
to CPU before cloning. CPU memory is typically abundant and this achieves the
same race-condition prevention without the GPU memory overhead.

Fixes Lightning-AI#21630
karthik-idikuda force-pushed the fix/async-checkpoint-clone-to-cpu branch from 95d6ff2 to 865d600 on March 31, 2026 10:55
codecov bot commented Mar 31, 2026

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 79%. Comparing base (612ab08) to head (019a502).
⚠️ Report is 4 commits behind head on master.
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (612ab08) and HEAD (019a502).

HEAD has 920 fewer uploads than BASE:

Flag               BASE (612ab08)   HEAD (019a502)
cpu                251              42
lightning_fabric   80               0
pytest             125              0
python3.12         72               12
python             18               3
lightning          90               15
python3.11         36               6
python3.13         53               9
python3.12.7       54               9
python3.10         18               3
pytorch_lightning  81               27
pytorch2.7         9                3
pytest-full        126              42
pytorch2.1         18               6
pytorch2.4.1       9                3
pytorch2.5.1       9                3
pytorch2.2.2       9                3
pytorch2.9         18               6
pytorch2.10        18               6
pytorch2.8         18               6
pytorch2.3         9                3
pytorch2.6         9                3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21631     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       23898    23877     -21     
=========================================
- Hits        20678    18799   -1879     
- Misses       3220     5078   +1858     

karthik-idikuda and others added 2 commits April 1, 2026 07:03
For CUDA tensors, cpu() already allocates a new host-memory copy, so an
additional clone() is unnecessary and wastes memory bandwidth. For CPU
tensors cpu() is a no-op, so clone() remains necessary.

Co-authored-by: TheGreatFrankie
justusschock (Member) commented Apr 1, 2026

I am not sure we want this to change. The reason being that a transfer from GPU to CPU is a synchronization point and the whole point of async checkpointing is to avoid those. If in doubt, you can still use synchronous checkpointing.

@TheGreatFrankie

> I am not sure we want this to change. The reason being that a transfer from GPU to CPU is a synchronization point and the whole point of async checkpointing is to avoid those. If in doubt, you can still use synchronous checkpointing.

🤔 But the original GPU-to-GPU clone is also a synchronization point, so I don't fully understand your point.
If a GPU-to-GPU clone is a synchronization point and a GPU-to-CPU copy is also one, why not clone to CPU and save GPU memory?


Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AsyncCheckpointIO Should Clone() to CPU not on GPU

4 participants