Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ class ArchonEngineConfig:
default="Interleaved1F1B",
metadata={
"help": "Pipeline parallel schedule type.",
"choices": ["1F1B", "Interleaved1F1B"],
"choices": ["1F1B", "Interleaved1F1B", "ZBVZeroBubble"],
},
)
# NOTE: The following three PP layer distribution parameters are advanced options
Expand All @@ -466,7 +466,7 @@ class ArchonEngineConfig:
"help": "Number of transformer layers per (virtual) pipeline stage. "
"If set, num_virtual_stages is calculated from num_layers. "
"If None, stages are inferred from schedule type "
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B).",
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B/ZBVZeroBubble).",
},
)
pp_first_stage_less_layers: int = field(
Expand Down
55 changes: 53 additions & 2 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.pipelining.schedules import (
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
get_schedule_class,
)
from transformers import (
AutoConfig,
PretrainedConfig,
Expand Down Expand Up @@ -239,9 +244,9 @@ def create_process_group(

# Pipeline parallel rank
if self.parallel_dims.pp_enabled:
pp_group = self.parallel_dims.get_group("pp")
self._pp_rank = self.parallel_dims.get_mesh("pp").get_local_rank()
self._pp_last_stage_rank = dist.get_process_group_ranks(pp_group)[-1]
# Set in _apply_pipeline_parallelism() after pipeline setup
self._pp_last_stage_rank = None
else:
self._pp_rank = 0
self._pp_last_stage_rank = None
Expand Down Expand Up @@ -297,6 +302,39 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
ac_config = self._build_ac_config()
enable_compile = self.config.archon.enable_compile

# V-style schedules (ZBVZeroBubble, DualPipeV) split backward into
# I (input grad) and W (weight grad) steps. This is incompatible with:
# 1. torch.compile — its donated buffer optimization assumes a single
# backward pass (retain_graph=False).
# 2. Op-level selective AC — its per-op cache (storage.pop) is consumed
# by the I step, leaving nothing for the W step recompute.
# 3. memory_budget AC — it depends on torch.compile.
# Full AC / layer-level selective AC use standard checkpoint_wrapper
# whose gid-based recompute supports multiple backward passes.
schedule_class = get_schedule_class(self.config.archon.pp_schedule)
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
if schedule_class in v_style_schedules:
schedule_name = self.config.archon.pp_schedule
if enable_compile:
self.logger.warning(
f"{schedule_name} is incompatible with torch.compile. "
"Disabling torch.compile."
)
enable_compile = False

if ac_config is not None and (
(
ac_config.mode == "selective"
and ac_config.selective_ac_option == "op"
)
or ac_config.mode == "memory_budget"
):
self.logger.warning(
f"{schedule_name} is incompatible with {ac_config.mode} AC. "
"Falling back to full AC."
)
ac_config.mode = "full"

# Force pad_to_maximum when compile is enabled to avoid dynamic shape issues
if enable_compile and not self.config.pad_to_maximum:
self.logger.info(
Expand Down Expand Up @@ -742,6 +780,7 @@ def onload(self) -> None:
self.is_offload = False

def export_stats(self) -> dict[str, float]:
assert self._initialized
data = stats_tracker.export_all(reduce_group=self.data_parallel_group)
if self.parallel_dims.pp_enabled:
data_list = [data]
Expand Down Expand Up @@ -832,6 +871,18 @@ def _apply_pipeline_parallelism(
# Delete original model to free memory
del self.model

# Determine which rank holds the last pipeline stage
pp_group = self.parallel_dims.get_group("pp")
pp_ranks = dist.get_process_group_ranks(pp_group)
schedule_class = get_schedule_class(self.config.archon.pp_schedule)
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
if schedule_class in v_style_schedules:
# V-style: rank 0 holds stages (0, num_stages-1)
self._pp_last_stage_rank = pp_ranks[0]
else:
# Loop-style: last rank has last stage
self._pp_last_stage_rank = pp_ranks[-1]

self.logger.info(
f"PP enabled: has_first={self.pp_has_first_stage}, "
f"has_last={self.pp_has_last_stage}"
Expand Down
47 changes: 30 additions & 17 deletions areal/experimental/models/archon/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from torch.distributed.pipelining.schedules import (
PipelineScheduleMulti,
PipelineScheduleSingle,
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
get_schedule_class,
)

Expand Down Expand Up @@ -210,7 +212,7 @@ def pipeline_module_split(
Args:
whole_model: The complete model to split
pp_mesh: Pipeline parallel device mesh
pp_schedule: Schedule type ("1F1B" or "Interleaved1F1B")
pp_schedule: Schedule type ("1F1B", "Interleaved1F1B", or "ZBVZeroBubble")
device: Target device for stages
module_names_per_stage: Module FQNs for each stage

Expand Down Expand Up @@ -297,29 +299,31 @@ def _get_stage_indices() -> tuple[int, ...]:
Examples (pp_degree=4, num_stages=8):
1F1B: Rank 0->(0,), Rank 1->(1,), ...
Interleaved1F1B: Rank 0->(0,4), Rank 1->(1,5), Rank 2->(2,6), Rank 3->(3,7)
ZBVZeroBubble: Rank 0->(0,7), Rank 1->(1,6), Rank 2->(2,5), Rank 3->(3,4)
"""
if num_stages % pp_degree != 0:
raise ValueError(
f"num_stages ({num_stages}) must be divisible by pp_degree ({pp_degree})"
f"num_stages ({num_stages}) must be evenly divisible by "
f"pp_degree ({pp_degree})"
)
stages_per_rank = num_stages // pp_degree

if pp_schedule == "1F1B":
if stages_per_rank != 1:
raise ValueError(
f"1F1B schedule requires exactly 1 stage per rank, "
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
)
return (pp_rank,)
elif pp_schedule == "Interleaved1F1B":
if stages_per_rank < 2:
schedule_class = get_schedule_class(pp_schedule)
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
style = "v" if schedule_class in v_style_schedules else "loop"

if style == "v":
if stages_per_rank != 2:
raise ValueError(
f"Interleaved1F1B schedule requires >= 2 stages per rank, "
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
f"V-style schedules require exactly 2 stages per rank, "
f"got {stages_per_rank}"
)
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return stage_v_pairs[pp_rank]
else:
raise ValueError(f"Unknown pp_schedule: {pp_schedule}")
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))

stages: list[PipelineStage] = []
model_parts: list[nn.Module] = []
Expand Down Expand Up @@ -351,7 +355,7 @@ def pipeline_llm(

Workflow:
1. Generate module names for each virtual stage
2. Split model into stages (multiple per rank for Interleaved1F1B)
2. Split model into stages (multiple per rank for Interleaved1F1B/ZBVZeroBubble)
3. Apply parallelization (TP, FSDP) to each model part

Args:
Expand All @@ -364,7 +368,7 @@ def pipeline_llm(

Returns:
Tuple of:
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B)
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B/ZBVZeroBubble)
- model_parts: List of model parts
- has_first_stage: Whether this rank has the first stage
- has_last_stage: Whether this rank has the last stage
Expand Down Expand Up @@ -429,6 +433,15 @@ def pipeline_llm(
f"but got {stages_per_rank} (from layers_per_stage={layers_per_stage}). "
f"Use 1F1B schedule for single stage per rank."
)
# V-style schedules require exactly 2 stages per rank
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
if schedule_class in v_style_schedules and stages_per_rank != 2:
raise ValueError(
f"{pp_schedule} requires exactly 2 stages per rank, "
f"but got {stages_per_rank}. "
f"Set pp_layers_per_stage to achieve 2 stages per rank, "
f"or let it default (None)."
)
else:
# Default: 1 for single-stage schedules, 2 for multi-stage schedules
stages_per_rank = 1 if is_single_stage_schedule else 2
Expand Down
93 changes: 75 additions & 18 deletions areal/tests/experimental/archon/test_distributed_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
pytest areal/tests/experimental/archon/test_distributed_pp.py -v -m multi_gpu

Test configuration:
2 GPU Tests (Core PP):
- test_pp_forward_2gpu: PP=2, tests forward pass matches golden model
- test_pp_backward_2gpu: PP=2, tests gradient flow through stages
2 GPU Tests (Core PP - manual P2P):
- test_pp_forward_2gpu: PP=2, manual activation passing (1F1B)
- test_pp_backward_2gpu: PP=2, manual gradient passing (1F1B)
- test_pp_gradient_correctness_2gpu: PP=2, tests PP gradients match non-PP

4 GPU Tests (Extended PP):
- test_pp_forward_4gpu: PP=4, tests forward with more stages
- test_pp_backward_4gpu: PP=4, tests backward with more stages
4 GPU Tests (Extended PP - manual P2P):
- test_pp_forward_4gpu: PP=4, manual activation passing (1F1B)
- test_pp_backward_4gpu: PP=4, manual gradient passing (1F1B)

Schedule API Tests (2 GPU):
- test_pp_zbv_forward_2gpu: PP=2, schedule.eval() with ZBVZeroBubble
- test_pp_zbv_backward_2gpu: PP=2, schedule.step() with ZBVZeroBubble

PP Combination Tests (4 GPU):
- test_pp_tp_forward_4gpu: PP=2, TP=2, tests PP+TP combination
Expand Down Expand Up @@ -107,34 +111,36 @@ def _run_pp_test_with_torchrun(
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_forward_2gpu():
    """Test PP forward pass with 2 GPUs (pp=2) via manual P2P.

    Validates that PP model output matches golden (non-PP) model output
    using manual activation passing between stages (1F1B only).

    The test is skipped when fewer than 2 devices are available; otherwise
    it launches the torchrun PP test script with ``--test_type=forward_p2p``.
    """
    if current_platform.device_count() < 2:
        pytest.skip("This test requires 2 GPUs")

    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=2,
        extra_args=["--test_type=forward_p2p", "--pp_size=2"],
    )


@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_backward_2gpu():
    """Test PP backward pass with 2 GPUs (pp=2) via manual P2P.

    Validates that gradients flow correctly through all PP stages
    using manual gradient passing between stages (1F1B only).

    The test is skipped when fewer than 2 devices are available; otherwise
    it launches the torchrun PP test script with ``--test_type=backward_p2p``.
    """
    if current_platform.device_count() < 2:
        pytest.skip("This test requires 2 GPUs")

    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=2,
        extra_args=["--test_type=backward_p2p", "--pp_size=2"],
    )


Expand All @@ -161,6 +167,55 @@ def test_pp_gradient_correctness_2gpu():
)


# =============================================================================
# Schedule API Tests (2 GPU)
# =============================================================================


@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_zbv_forward_2gpu():
    """Run the ZBVZeroBubble forward check on 2 GPUs (pp=2) via the schedule API.

    Launches the torchrun PP test script in ``forward_schedule`` mode with the
    ``ZBVZeroBubble`` schedule, which is intended to validate the PP model
    output produced through ``schedule.eval()`` under V-style stage assignment.

    The test is skipped when fewer than 2 devices are available.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")

    # Arguments forwarded verbatim to the torchrun test script.
    script_path = "areal/tests/experimental/archon/torchrun/run_pp_tests.py"
    cli_flags = [
        "--test_type=forward_schedule",
        "--pp_size=2",
        "--pp_schedule=ZBVZeroBubble",
    ]
    _run_pp_test_with_torchrun(
        script_path,
        n_gpus=required_gpus,
        extra_args=cli_flags,
    )


@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_zbv_backward_2gpu():
    """Run the ZBVZeroBubble backward check on 2 GPUs (pp=2) via the schedule API.

    Launches the torchrun PP test script in ``backward_schedule`` mode with the
    ``ZBVZeroBubble`` schedule, which is intended to validate that gradients
    flow through every PP stage when driven by ``schedule.step()`` under
    V-style stage assignment.

    The test is skipped when fewer than 2 devices are available.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")

    # Arguments forwarded verbatim to the torchrun test script.
    script_path = "areal/tests/experimental/archon/torchrun/run_pp_tests.py"
    cli_flags = [
        "--test_type=backward_schedule",
        "--pp_size=2",
        "--pp_schedule=ZBVZeroBubble",
    ]
    _run_pp_test_with_torchrun(
        script_path,
        n_gpus=required_gpus,
        extra_args=cli_flags,
    )


# =============================================================================
# 4 GPU Tests (Extended PP tests)
# =============================================================================
Expand All @@ -169,34 +224,36 @@ def test_pp_gradient_correctness_2gpu():
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_forward_4gpu():
    """Test PP forward pass with 4 GPUs (pp=4) via manual P2P.

    Validates PP with more stages (4 stages instead of 2) using
    manual activation passing (1F1B only).

    The test is skipped when fewer than 4 devices are available; otherwise
    it launches the torchrun PP test script with ``--test_type=forward_p2p``.
    """
    if current_platform.device_count() < 4:
        pytest.skip("This test requires 4 GPUs")

    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=4,
        extra_args=["--test_type=forward_p2p", "--pp_size=4"],
    )


@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_backward_4gpu():
    """Test PP backward pass with 4 GPUs (pp=4) via manual P2P.

    Validates gradient flow with more stages using manual gradient
    passing (1F1B only).

    The test is skipped when fewer than 4 devices are available; otherwise
    it launches the torchrun PP test script with ``--test_type=backward_p2p``.
    """
    if current_platform.device_count() < 4:
        pytest.skip("This test requires 4 GPUs")

    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=4,
        extra_args=["--test_type=backward_p2p", "--pp_size=4"],
    )


Expand Down
Loading