Skip to content

Commit 4a548c9

Browse files
fix: restore cached AMP step context after no_grad workaround (#21616)
* fix: restore cached AMP step context after no_grad workaround
* chore: trigger ci
* chore: trigger ci
* test: add CUDA coverage for AMP no_grad cache handling
1 parent d3b25f8 commit 4a548c9

File tree

2 files changed

+157
-5
lines changed
  • src/lightning/pytorch/plugins/precision/amp.py
  • tests/tests_pytorch/plugins/precision/test_amp.py

2 files changed

+157
-5
lines changed

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# limitations under the License.
1212
from collections.abc import Generator
1313
from contextlib import contextmanager
14-
from typing import Any, Callable, Literal, Optional, Union
14+
from typing import Any, Callable, Literal, Optional, Union, cast
1515

1616
import torch
1717
from torch import Tensor
@@ -27,6 +27,26 @@
2727
from lightning.pytorch.utilities.exceptions import MisconfigurationException
2828

2929

30+
class _AutocastClearCacheOnExit:
31+
"""Proxy a grad-disabling context manager and clear the autocast cache when it exits."""
32+
33+
def __init__(self, context_manager: Any, *, clear_cache: bool) -> None:
34+
self._context_manager = context_manager
35+
self._clear_cache = clear_cache
36+
37+
def __enter__(self) -> Any:
38+
return self._context_manager.__enter__()
39+
40+
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Any:
41+
out = self._context_manager.__exit__(exc_type, exc, tb)
42+
if self._clear_cache:
43+
torch.clear_autocast_cache()
44+
return out
45+
46+
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
47+
return self._context_manager(func)
48+
49+
3050
class MixedPrecision(Precision):
3151
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.
3252
@@ -118,9 +138,39 @@ def autocast_context_manager(self) -> torch.autocast:
118138
@override
119139
@contextmanager
120140
def forward_context(self) -> Generator[None, None, None]:
121-
"""Enable autocast context."""
122-
with self.autocast_context_manager():
123-
yield
141+
"""Enable autocast and clear cached casts after nested grad-disabling contexts exit."""
142+
original_no_grad = torch.no_grad
143+
original_inference_mode = torch.inference_mode
144+
145+
def _clear_cache_on_exit(
146+
context_factory: Callable[..., Any], *, clear_cache: Callable[..., bool]
147+
) -> Callable[..., Any]:
148+
def wrapper(*args: Any, **kwargs: Any) -> _AutocastClearCacheOnExit:
149+
return _AutocastClearCacheOnExit(
150+
context_factory(*args, **kwargs),
151+
clear_cache=clear_cache(*args, **kwargs),
152+
)
153+
154+
return wrapper
155+
156+
try:
157+
# Lightning wraps the whole step in a persistent autocast context. If a nested `no_grad` or
158+
# `inference_mode` block creates cached casts there, later grad-enabled forwards in the same step can
159+
# incorrectly reuse them. Clear the autocast cache when such nested contexts exit, while keeping the
160+
# default cached path for normal training.
161+
torch_module = cast(Any, torch)
162+
torch_module.no_grad = _clear_cache_on_exit(original_no_grad, clear_cache=lambda *args, **kwargs: True)
163+
torch_module.inference_mode = _clear_cache_on_exit(
164+
original_inference_mode,
165+
clear_cache=lambda *args, **kwargs: bool(args[0] if args else kwargs.get("mode", True)),
166+
)
167+
dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
168+
with torch.autocast(self.device, dtype=dtype):
169+
yield
170+
finally:
171+
torch_module = cast(Any, torch)
172+
torch_module.no_grad = original_no_grad
173+
torch_module.inference_mode = original_inference_mode
124174

125175
@override
126176
def state_dict(self) -> dict[str, Any]:

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from lightning.pytorch.plugins import MixedPrecision
2222
from lightning.pytorch.utilities import GradClipAlgorithmType
23+
from tests_pytorch.helpers.runif import RunIf
2324

2425

2526
def test_clip_gradients():
@@ -62,10 +63,111 @@ def test_amp_with_no_grad():
6263
x = torch.randn(1, 2)
6364
amp = MixedPrecision(precision="bf16-mixed", device="cpu")
6465

65-
with amp.autocast_context_manager():
66+
with amp.forward_context():
6667
with torch.no_grad():
6768
_ = layer(x)
6869

6970
loss = layer(x).mean()
7071
loss.backward()
7172
assert loss.grad_fn is not None
73+
74+
75+
def test_amp_with_inference_mode():
76+
"""Test that nested `inference_mode` also clears the autocast cache on exit."""
77+
layer = nn.Linear(2, 1)
78+
x = torch.randn(1, 2)
79+
amp = MixedPrecision(precision="bf16-mixed", device="cpu")
80+
81+
with amp.forward_context():
82+
with torch.inference_mode():
83+
_ = layer(x)
84+
85+
loss = layer(x).mean()
86+
loss.backward()
87+
assert loss.grad_fn is not None
88+
89+
90+
def test_amp_forward_context_restores_grad_mode_context_managers():
91+
amp = MixedPrecision(precision="bf16-mixed", device="cpu")
92+
original_no_grad = torch.no_grad
93+
original_inference_mode = torch.inference_mode
94+
95+
with amp.forward_context():
96+
assert torch.no_grad is not original_no_grad
97+
assert torch.inference_mode is not original_inference_mode
98+
99+
assert torch.no_grad is original_no_grad
100+
assert torch.inference_mode is original_inference_mode
101+
102+
103+
@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
104+
def test_torch_autocast_cache_behavior_with_no_grad(cache_enabled, expect_grad):
105+
"""Document the underlying PyTorch autocast behavior that this plugin needs to handle."""
106+
layer = nn.Linear(2, 1)
107+
x = torch.randn(1, 2)
108+
109+
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled):
110+
with torch.no_grad():
111+
_ = layer(x)
112+
113+
loss = layer(x).mean()
114+
if expect_grad:
115+
loss.backward()
116+
assert loss.grad_fn is not None
117+
else:
118+
assert loss.grad_fn is None
119+
with pytest.raises(RuntimeError, match="does not require grad"):
120+
loss.backward()
121+
122+
123+
@RunIf(min_cuda_gpus=1)
124+
@pytest.mark.parametrize(("cache_enabled", "expect_grad"), [(True, False), (False, True)])
125+
def test_torch_autocast_cache_behavior_with_no_grad_cuda(cache_enabled, expect_grad):
126+
"""Document the same autocast cache behavior on CUDA, where the reported regression happens."""
127+
layer = nn.Linear(2, 1, device="cuda")
128+
x = torch.randn(1, 2, device="cuda")
129+
130+
with torch.autocast("cuda", dtype=torch.float16, cache_enabled=cache_enabled):
131+
with torch.no_grad():
132+
_ = layer(x)
133+
134+
loss = layer(x).mean()
135+
if expect_grad:
136+
loss.backward()
137+
assert loss.grad_fn is not None
138+
else:
139+
assert loss.grad_fn is None
140+
with pytest.raises(RuntimeError, match="does not require grad"):
141+
loss.backward()
142+
143+
144+
@RunIf(min_cuda_gpus=1)
145+
def test_amp_with_no_grad_cuda():
146+
"""Test the Lightning workaround on the CUDA path used by the reported regression."""
147+
layer = nn.Linear(2, 1, device="cuda")
148+
x = torch.randn(1, 2, device="cuda")
149+
amp = MixedPrecision(precision="16-mixed", device="cuda")
150+
151+
with amp.forward_context():
152+
with torch.no_grad():
153+
_ = layer(x)
154+
155+
loss = layer(x).mean()
156+
loss.backward()
157+
assert loss.grad_fn is not None
158+
159+
160+
def test_amp_autocast_context_manager_disables_cache():
161+
"""Test that the public autocast context manager preserves the existing no-cache workaround."""
162+
amp = MixedPrecision(precision="bf16-mixed", device="cpu")
163+
164+
with amp.autocast_context_manager():
165+
assert not torch.is_autocast_cache_enabled()
166+
167+
168+
def test_amp_forward_context_keeps_cache_enabled():
169+
"""Test that Lightning's internal step context keeps the cached autocast path enabled."""
170+
amp = MixedPrecision(precision="bf16-mixed", device="cpu")
171+
172+
with amp.forward_context():
173+
assert torch.is_autocast_cache_enabled()

0 commit comments

Comments (0)