Commit 71adffe

Authored and committed by Nytrynox
fix: clone async checkpoint tensors to CPU to prevent GPU OOM
Move tensor cloning to CPU in AsyncCheckpointIO._clone_tensor() to prevent doubling GPU memory usage during async checkpoint saves.

Previously, _clone_tensor() called t.detach().clone(), which allocates new GPU memory for each cloned tensor. For large model checkpoints (e.g., 15 GB+), this can cause GPU OOM errors, since the entire checkpoint is temporarily duplicated in GPU memory.

The fix changes the operation to t.detach().cpu().clone(), which moves tensors to CPU before cloning. CPU memory is typically abundant, and this achieves the same race-condition prevention without the GPU memory overhead.

Fixes #21630
1 parent 612ab08 commit 71adffe
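The commit message describes snapshotting a checkpoint on the caller thread, with tensors moved to CPU before cloning so the point-in-time copy never doubles GPU memory. A minimal, self-contained sketch of that idea (not the Lightning source; `clone_checkpoint_to_cpu` is a hypothetical helper for illustration):

```python
import torch

def clone_checkpoint_to_cpu(obj):
    """Recursively take a point-in-time CPU copy of tensors in a checkpoint-like structure."""
    if isinstance(obj, torch.Tensor):
        # detach from autograd, move to CPU first (so clone() does not allocate
        # new GPU memory), then clone for an independent copy
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        return {k: clone_checkpoint_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(clone_checkpoint_to_cpu(v) for v in obj)
    return obj

checkpoint = {"state_dict": {"weight": torch.ones(2, 2)}, "epoch": 3}
snapshot = clone_checkpoint_to_cpu(checkpoint)

# mutating the original after snapshotting (as training would while a
# background thread writes the file) does not affect the copy
checkpoint["state_dict"]["weight"].zero_()
print(torch.equal(snapshot["state_dict"]["weight"], torch.ones(2, 2)))  # True
```

Because the snapshot is taken synchronously on the caller thread, the background writer only ever sees the frozen CPU copies, which is the race-condition guarantee the original `t.detach().clone()` already provided, now without the transient GPU duplication.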

File tree

2 files changed: +21 additions, −3 deletions


src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 5 additions & 3 deletions
@@ -95,6 +95,8 @@ def teardown(self) -> None:

 # snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
 def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
-    """Clones a tensor on the caller thread."""
-    # detach to avoid autograd history and clone to take a point-in-time copy
-    return t.detach().clone()
+    """Clones a tensor to CPU on the caller thread."""
+    # detach to avoid autograd history, move to CPU to avoid doubling GPU memory usage, and clone to take a
+    # point-in-time copy. Moving to CPU first is important because clone() on a CUDA tensor allocates new GPU
+    # memory, which can cause OOM errors for large model checkpoints.
+    return t.detach().cpu().clone()

tests/tests_pytorch/plugins/test_async_checkpoint.py

Lines changed: 16 additions & 0 deletions
@@ -51,3 +51,19 @@ def test_async_checkpoint_should_snapshot_values_before_mutation():
         "AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
         "to avoid races with parameter mutation; got mutated value instead"
     )
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+def test_async_checkpoint_clones_tensors_to_cpu():
+    """Verify that _clone_tensor moves tensors to CPU to avoid doubling GPU memory usage."""
+    from lightning.pytorch.plugins.io.async_plugin import _clone_tensor
+
+    t = torch.tensor([1.0, 2.0, 3.0])
+    cloned = _clone_tensor(t)
+
+    # cloned tensor should be on CPU
+    assert cloned.device == torch.device("cpu"), f"Expected CPU tensor, got {cloned.device}"
+    # values should match
+    assert torch.equal(cloned, t)
+    # cloned tensor should not share storage with the original
+    assert cloned.data_ptr() != t.data_ptr()
