Skip to content

Commit 805c47d

Browse files
committed
fix: use fsspec for DeepSpeed checkpoint path validation to support remote URIs
Fixes #21635. _validate_checkpoint_directory was wrapping paths in pathlib.Path, which mangles remote URI schemes (e.g. s3:// to s3:/) and uses local-only .is_dir()/.is_file() checks that always return False for remote paths like S3, GCS, or HDFS. Replace with get_filesystem() + fs.isdir()/fs.isfile() from cloud_io, which is the established pattern used by ModelCheckpoint, TorchCheckpointIO, and CheckpointConnector.
1 parent 612ab08 commit 805c47d

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from lightning.fabric.strategies.ddp import DDPStrategy
3737
from lightning.fabric.strategies.registry import _StrategyRegistry
3838
from lightning.fabric.strategies.strategy import _Sharded
39+
from lightning.fabric.utilities.cloud_io import get_filesystem
3940
from lightning.fabric.utilities.distributed import log
4041
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
4142
from lightning.fabric.utilities.load import _move_state_into
@@ -45,6 +46,7 @@
4546

4647
if TYPE_CHECKING:
4748
from deepspeed import DeepSpeedEngine
49+
from fsspec import AbstractFileSystem
4850
from torch.optim.lr_scheduler import _LRScheduler
4951

5052
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
@@ -885,9 +887,9 @@ def _validate_device_index_selection(parallel_devices: list[torch.device]) -> No
885887
)
886888

887889

888-
def _is_deepspeed_checkpoint(path: Path) -> bool:
890+
def _is_deepspeed_checkpoint(path: str, fs: "AbstractFileSystem") -> bool:
889891
"""Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory."""
890-
return path.is_dir() and (path / "checkpoint").is_dir()
892+
return fs.isdir(path) and fs.isdir(f"{path.rstrip('/')}/checkpoint")
891893

892894

893895
def _validate_checkpoint_directory(path: _PATH) -> None:
@@ -903,25 +905,28 @@ def _validate_checkpoint_directory(path: _PATH) -> None:
903905
# ├── latest
904906
# └── zero_to_fp32.py
905907

906-
path = Path(path)
907-
path_is_ds_checkpoint = _is_deepspeed_checkpoint(path)
908-
default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}"
908+
path_str = str(path)
909+
fs = get_filesystem(path_str)
910+
path_is_ds_checkpoint = _is_deepspeed_checkpoint(path_str, fs)
911+
default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path_str}"
909912

910913
if not path_is_ds_checkpoint:
911914
# Case 1: User may have accidentally passed the subfolder "checkpoint"
912-
parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent)
913-
if parent_is_ds_checkpoint:
915+
parent = path_str.rstrip("/").rsplit("/", 1)[0] if "/" in path_str else ""
916+
if parent and _is_deepspeed_checkpoint(parent, fs):
914917
raise FileNotFoundError(
915918
f"{default_message}. It looks like you passed the path to a subfolder."
916-
f" Try to load using this parent directory instead: {path.parent}"
919+
f" Try to load using this parent directory instead: {parent}"
917920
)
918921
# Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder
919-
parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent)
920-
if parent_parent_is_ds_checkpoint:
921-
raise FileNotFoundError(
922-
f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder."
923-
f" Try to load using this parent directory instead: {path.parent.parent}"
924-
)
922+
if parent and fs.isfile(path_str):
923+
grandparent = parent.rstrip("/").rsplit("/", 1)[0] if "/" in parent else ""
924+
if grandparent and _is_deepspeed_checkpoint(grandparent, fs):
925+
raise FileNotFoundError(
926+
f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed"
927+
f" checkpoint folder."
928+
f" Try to load using this parent directory instead: {grandparent}"
929+
)
925930
raise FileNotFoundError(default_message)
926931

927932

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,40 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path):
264264
strategy.load_checkpoint(path=checkpoint_path, state={"model": Mock()})
265265

266266

267+
def test_validate_checkpoint_directory_remote_uri():
268+
"""Test that _validate_checkpoint_directory works with remote filesystem URIs (e.g., S3, HDFS)."""
269+
from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory
270+
271+
mock_fs = Mock()
272+
mock_fs.isdir = Mock(side_effect=lambda p: p in ("s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"))
273+
mock_fs.isfile = Mock(return_value=False)
274+
275+
with mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs):
276+
# Should not raise when the remote path is a valid DeepSpeed checkpoint
277+
_validate_checkpoint_directory("s3://bucket/ckpt")
278+
279+
# Verify the URI was NOT mangled (s3:// must stay as s3://, not s3:/)
280+
mock_fs.isdir.assert_any_call("s3://bucket/ckpt")
281+
mock_fs.isdir.assert_any_call("s3://bucket/ckpt/checkpoint")
282+
283+
284+
def test_validate_checkpoint_directory_remote_uri_subfolder_suggestion():
285+
"""Test that the subfolder suggestion works with remote URIs."""
286+
from lightning.fabric.strategies.deepspeed import _validate_checkpoint_directory
287+
288+
mock_fs = Mock()
289+
mock_fs.isdir = Mock(
290+
side_effect=lambda p: p in ("s3://bucket/ckpt", "s3://bucket/ckpt/checkpoint"),
291+
)
292+
mock_fs.isfile = Mock(return_value=False)
293+
294+
with (
295+
mock.patch("lightning.fabric.strategies.deepspeed.get_filesystem", return_value=mock_fs),
296+
pytest.raises(FileNotFoundError, match="Try to load using this parent directory instead: s3://bucket/ckpt"),
297+
):
298+
_validate_checkpoint_directory("s3://bucket/ckpt/checkpoint")
299+
300+
267301
@RunIf(deepspeed=True)
268302
def test_deepspeed_load_checkpoint_no_state(tmp_path):
269303
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""

0 commit comments

Comments
 (0)