Skip to content

Commit 3e9cda4

Browse files
Copilotvfdev-5
andauthored
fix: cast scaler.scale(loss) result to torch.Tensor to fix pyrefly type error
Agent-Logs-Url: https://github.com/pytorch/ignite/sessions/4d2b947c-13f4-40a8-b7ef-8760007f3602 Co-authored-by: vfdev-5 <2459423+vfdev-5@users.noreply.github.com>
1 parent 69a95b2 commit 3e9cda4

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

ignite/engine/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable, Mapping, Sequence
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, cast
55

66
if TYPE_CHECKING:
77
# GradScaler is imported here rather than used as a string literal ("torch.amp.GradScaler")
@@ -221,7 +221,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Any | tuple[torch.T
221221
if gradient_accumulation_steps > 1:
222222
loss = loss / gradient_accumulation_steps
223223
if scaler:
224-
scaler.scale(loss).backward()
224+
cast(torch.Tensor, scaler.scale(loss)).backward()
225225
if engine.state.iteration % gradient_accumulation_steps == 0:
226226
scaler.step(optimizer)
227227
scaler.update()

0 commit comments

Comments
 (0)