Skip to content

Commit b55dbc0

Browse files
perf: skip redundant clone() for CUDA tensors in async checkpoint
For CUDA tensors, cpu() already allocates a new host-memory copy, so an additional clone() is unnecessary and wastes memory bandwidth. For CPU tensors cpu() is a no-op, so clone() remains necessary. Co-authored-by: TheGreatFrankie
1 parent 865d600 commit b55dbc0

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,15 @@ def teardown(self) -> None:
9595

9696
# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
9797
def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
98-
"""Clones a tensor to CPU on the caller thread."""
99-
# detach to avoid autograd history, move to CPU to avoid doubling GPU memory usage, and clone to take a
100-
# point-in-time copy. Moving to CPU first is important because clone() on a CUDA tensor allocates new GPU memory,
101-
# which can cause OOM errors for large model checkpoints.
102-
return t.detach().cpu().clone()
98+
"""Clone a tensor to CPU memory.
99+
100+
Detaches from autograd, moves to CPU, and ensures a point-in-time snapshot
101+
that won't be mutated by ongoing training.
102+
103+
For CUDA tensors ``cpu()`` already allocates a new host-memory copy, so an
104+
extra ``clone()`` is unnecessary. For CPU tensors ``cpu()`` is a no-op, so
105+
``clone()`` is required to break storage sharing.
106+
"""
107+
if t.is_cuda:
108+
return t.detach().cpu()
109+
return t.detach().clone()

tests/tests_pytorch/plugins/test_async_checkpoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_async_checkpoint_should_snapshot_values_before_mutation():
5555

5656
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
5757
def test_async_checkpoint_clones_tensors_to_cpu():
58-
"""Verify that _clone_tensor moves tensors to CPU to avoid doubling GPU memory usage."""
58+
"""Verify that _clone_tensor produces a CPU snapshot that does not share storage."""
5959
from lightning.pytorch.plugins.io.async_plugin import _clone_tensor
6060

6161
t = torch.tensor([1.0, 2.0, 3.0])
@@ -67,3 +67,6 @@ def test_async_checkpoint_clones_tensors_to_cpu():
6767
assert torch.equal(cloned, t)
6868
# cloned tensor should not share storage with the original
6969
assert cloned.data_ptr() != t.data_ptr()
70+
# mutation of the original must not affect the clone
71+
t.add_(1.0)
72+
assert torch.equal(cloned, torch.tensor([1.0, 2.0, 3.0]))

0 commit comments

Comments (0)