Skip to content

Commit ab411da

Browse files
authored
feat(archon): add ZBVZeroBubble pipeline schedule support (#916)
Add V-style (zero bubble) pipeline scheduling to ArchonEngine. ZBVZeroBubble splits backward into input-grad and weight-grad steps, enabling near-zero pipeline bubbles with 2 stages per rank.

Key changes:
- V-style stage assignment in _get_stage_indices() (rank 0 gets first and last stages)
- Schedule-aware _pp_last_stage_rank determination
- Auto-disable torch.compile and op-level selective AC for V-style schedules (incompatible with split backward)
- Generalize V-style guards to also cover ScheduleDualPipeV for forward compatibility
- Add ZBV forward/backward distributed tests
1 parent e03f32f commit ab411da

File tree

7 files changed

+752
-399
lines changed

7 files changed

+752
-399
lines changed

areal/api/cli_args.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ class ArchonEngineConfig:
453453
default="Interleaved1F1B",
454454
metadata={
455455
"help": "Pipeline parallel schedule type.",
456-
"choices": ["1F1B", "Interleaved1F1B"],
456+
"choices": ["1F1B", "Interleaved1F1B", "ZBVZeroBubble"],
457457
},
458458
)
459459
# NOTE: The following three PP layer distribution parameters are advanced options
@@ -466,7 +466,7 @@ class ArchonEngineConfig:
466466
"help": "Number of transformer layers per (virtual) pipeline stage. "
467467
"If set, num_virtual_stages is calculated from num_layers. "
468468
"If None, stages are inferred from schedule type "
469-
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B).",
469+
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B/ZBVZeroBubble).",
470470
},
471471
)
472472
pp_first_stage_less_layers: int = field(

areal/experimental/engine/archon_engine.py

Lines changed: 53 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,11 @@
1212
import torch
1313
import torch.distributed as dist
1414
from torch import nn
15+
from torch.distributed.pipelining.schedules import (
16+
ScheduleDualPipeV,
17+
ScheduleZBVZeroBubble,
18+
get_schedule_class,
19+
)
1520
from transformers import (
1621
AutoConfig,
1722
PretrainedConfig,
@@ -239,9 +244,9 @@ def create_process_group(
239244

240245
# Pipeline parallel rank
241246
if self.parallel_dims.pp_enabled:
242-
pp_group = self.parallel_dims.get_group("pp")
243247
self._pp_rank = self.parallel_dims.get_mesh("pp").get_local_rank()
244-
self._pp_last_stage_rank = dist.get_process_group_ranks(pp_group)[-1]
248+
# Set in _apply_pipeline_parallelism() after pipeline setup
249+
self._pp_last_stage_rank = None
245250
else:
246251
self._pp_rank = 0
247252
self._pp_last_stage_rank = None
@@ -297,6 +302,39 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
297302
ac_config = self._build_ac_config()
298303
enable_compile = self.config.archon.enable_compile
299304

305+
# V-style schedules (ZBVZeroBubble, DualPipeV) split backward into
306+
# I (input grad) and W (weight grad) steps. This is incompatible with:
307+
# 1. torch.compile — its donated buffer optimization assumes a single
308+
# backward pass (retain_graph=False).
309+
# 2. Op-level selective AC — its per-op cache (storage.pop) is consumed
310+
# by the I step, leaving nothing for the W step recompute.
311+
# 3. memory_budget AC — it depends on torch.compile.
312+
# Full AC / layer-level selective AC use standard checkpoint_wrapper
313+
# whose gid-based recompute supports multiple backward passes.
314+
schedule_class = get_schedule_class(self.config.archon.pp_schedule)
315+
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
316+
if schedule_class in v_style_schedules:
317+
schedule_name = self.config.archon.pp_schedule
318+
if enable_compile:
319+
self.logger.warning(
320+
f"{schedule_name} is incompatible with torch.compile. "
321+
"Disabling torch.compile."
322+
)
323+
enable_compile = False
324+
325+
if ac_config is not None and (
326+
(
327+
ac_config.mode == "selective"
328+
and ac_config.selective_ac_option == "op"
329+
)
330+
or ac_config.mode == "memory_budget"
331+
):
332+
self.logger.warning(
333+
f"{schedule_name} is incompatible with {ac_config.mode} AC. "
334+
"Falling back to full AC."
335+
)
336+
ac_config.mode = "full"
337+
300338
# Force pad_to_maximum when compile is enabled to avoid dynamic shape issues
301339
if enable_compile and not self.config.pad_to_maximum:
302340
self.logger.info(
@@ -742,6 +780,7 @@ def onload(self) -> None:
742780
self.is_offload = False
743781

744782
def export_stats(self) -> dict[str, float]:
783+
assert self._initialized
745784
data = stats_tracker.export_all(reduce_group=self.data_parallel_group)
746785
if self.parallel_dims.pp_enabled:
747786
data_list = [data]
@@ -832,6 +871,18 @@ def _apply_pipeline_parallelism(
832871
# Delete original model to free memory
833872
del self.model
834873

874+
# Determine which rank holds the last pipeline stage
875+
pp_group = self.parallel_dims.get_group("pp")
876+
pp_ranks = dist.get_process_group_ranks(pp_group)
877+
schedule_class = get_schedule_class(self.config.archon.pp_schedule)
878+
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
879+
if schedule_class in v_style_schedules:
880+
# V-style: rank 0 holds stages (0, num_stages-1)
881+
self._pp_last_stage_rank = pp_ranks[0]
882+
else:
883+
# Loop-style: last rank has last stage
884+
self._pp_last_stage_rank = pp_ranks[-1]
885+
835886
self.logger.info(
836887
f"PP enabled: has_first={self.pp_has_first_stage}, "
837888
f"has_last={self.pp_has_last_stage}"

areal/experimental/models/archon/pipeline_parallel.py

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414
from torch.distributed.pipelining.schedules import (
1515
PipelineScheduleMulti,
1616
PipelineScheduleSingle,
17+
ScheduleDualPipeV,
18+
ScheduleZBVZeroBubble,
1719
get_schedule_class,
1820
)
1921

@@ -210,7 +212,7 @@ def pipeline_module_split(
210212
Args:
211213
whole_model: The complete model to split
212214
pp_mesh: Pipeline parallel device mesh
213-
pp_schedule: Schedule type ("1F1B" or "Interleaved1F1B")
215+
pp_schedule: Schedule type ("1F1B", "Interleaved1F1B", or "ZBVZeroBubble")
214216
device: Target device for stages
215217
module_names_per_stage: Module FQNs for each stage
216218
@@ -297,29 +299,31 @@ def _get_stage_indices() -> tuple[int, ...]:
297299
Examples (pp_degree=4, num_stages=8):
298300
1F1B: Rank 0->(0,), Rank 1->(1,), ...
299301
Interleaved1F1B: Rank 0->(0,4), Rank 1->(1,5), Rank 2->(2,6), Rank 3->(3,7)
302+
ZBVZeroBubble: Rank 0->(0,7), Rank 1->(1,6), Rank 2->(2,5), Rank 3->(3,4)
300303
"""
301304
if num_stages % pp_degree != 0:
302305
raise ValueError(
303-
f"num_stages ({num_stages}) must be divisible by pp_degree ({pp_degree})"
306+
f"num_stages ({num_stages}) must be evenly divisible by "
307+
f"pp_degree ({pp_degree})"
304308
)
305309
stages_per_rank = num_stages // pp_degree
306310

307-
if pp_schedule == "1F1B":
308-
if stages_per_rank != 1:
309-
raise ValueError(
310-
f"1F1B schedule requires exactly 1 stage per rank, "
311-
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
312-
)
313-
return (pp_rank,)
314-
elif pp_schedule == "Interleaved1F1B":
315-
if stages_per_rank < 2:
311+
schedule_class = get_schedule_class(pp_schedule)
312+
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
313+
style = "v" if schedule_class in v_style_schedules else "loop"
314+
315+
if style == "v":
316+
if stages_per_rank != 2:
316317
raise ValueError(
317-
f"Interleaved1F1B schedule requires >= 2 stages per rank, "
318-
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
318+
f"V-style schedules require exactly 2 stages per rank, "
319+
f"got {stages_per_rank}"
319320
)
320-
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
321+
stage_v_pairs = list(
322+
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
323+
)
324+
return stage_v_pairs[pp_rank]
321325
else:
322-
raise ValueError(f"Unknown pp_schedule: {pp_schedule}")
326+
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
323327

324328
stages: list[PipelineStage] = []
325329
model_parts: list[nn.Module] = []
@@ -351,7 +355,7 @@ def pipeline_llm(
351355
352356
Workflow:
353357
1. Generate module names for each virtual stage
354-
2. Split model into stages (multiple per rank for Interleaved1F1B)
358+
2. Split model into stages (multiple per rank for Interleaved1F1B/ZBVZeroBubble)
355359
3. Apply parallelization (TP, FSDP) to each model part
356360
357361
Args:
@@ -364,7 +368,7 @@ def pipeline_llm(
364368
365369
Returns:
366370
Tuple of:
367-
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B)
371+
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B/ZBVZeroBubble)
368372
- model_parts: List of model parts
369373
- has_first_stage: Whether this rank has the first stage
370374
- has_last_stage: Whether this rank has the last stage
@@ -429,6 +433,15 @@ def pipeline_llm(
429433
f"but got {stages_per_rank} (from layers_per_stage={layers_per_stage}). "
430434
f"Use 1F1B schedule for single stage per rank."
431435
)
436+
# V-style schedules require exactly 2 stages per rank
437+
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
438+
if schedule_class in v_style_schedules and stages_per_rank != 2:
439+
raise ValueError(
440+
f"{pp_schedule} requires exactly 2 stages per rank, "
441+
f"but got {stages_per_rank}. "
442+
f"Set pp_layers_per_stage to achieve 2 stages per rank, "
443+
f"or let it default (None)."
444+
)
432445
else:
433446
# Default: 1 for single-stage schedules, 2 for multi-stage schedules
434447
stages_per_rank = 1 if is_single_stage_schedule else 2

areal/tests/experimental/archon/test_distributed_pp.py

Lines changed: 75 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -6,14 +6,18 @@
66
pytest areal/tests/experimental/archon/test_distributed_pp.py -v -m multi_gpu
77
88
Test configuration:
9-
2 GPU Tests (Core PP):
10-
- test_pp_forward_2gpu: PP=2, tests forward pass matches golden model
11-
- test_pp_backward_2gpu: PP=2, tests gradient flow through stages
9+
2 GPU Tests (Core PP - manual P2P):
10+
- test_pp_forward_2gpu: PP=2, manual activation passing (1F1B)
11+
- test_pp_backward_2gpu: PP=2, manual gradient passing (1F1B)
1212
- test_pp_gradient_correctness_2gpu: PP=2, tests PP gradients match non-PP
1313
14-
4 GPU Tests (Extended PP):
15-
- test_pp_forward_4gpu: PP=4, tests forward with more stages
16-
- test_pp_backward_4gpu: PP=4, tests backward with more stages
14+
4 GPU Tests (Extended PP - manual P2P):
15+
- test_pp_forward_4gpu: PP=4, manual activation passing (1F1B)
16+
- test_pp_backward_4gpu: PP=4, manual gradient passing (1F1B)
17+
18+
Schedule API Tests (2 GPU):
19+
- test_pp_zbv_forward_2gpu: PP=2, schedule.eval() with ZBVZeroBubble
20+
- test_pp_zbv_backward_2gpu: PP=2, schedule.step() with ZBVZeroBubble
1721
1822
PP Combination Tests (4 GPU):
1923
- test_pp_tp_forward_4gpu: PP=2, TP=2, tests PP+TP combination
@@ -107,34 +111,36 @@ def _run_pp_test_with_torchrun(
107111
@pytest.mark.multi_gpu
108112
@pytest.mark.slow
109113
def test_pp_forward_2gpu():
110-
"""Test PP forward pass with 2 GPUs (pp=2).
114+
"""Test PP forward pass with 2 GPUs (pp=2) via manual P2P.
111115
112-
Validates that PP model output matches golden (non-PP) model output.
116+
Validates that PP model output matches golden (non-PP) model output
117+
using manual activation passing between stages (1F1B only).
113118
"""
114119
if current_platform.device_count() < 2:
115120
pytest.skip("This test requires 2 GPUs")
116121

117122
_run_pp_test_with_torchrun(
118123
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
119124
n_gpus=2,
120-
extra_args=["--test_type=forward", "--pp_size=2"],
125+
extra_args=["--test_type=forward_p2p", "--pp_size=2"],
121126
)
122127

123128

124129
@pytest.mark.multi_gpu
125130
@pytest.mark.slow
126131
def test_pp_backward_2gpu():
127-
"""Test PP backward pass with 2 GPUs (pp=2).
132+
"""Test PP backward pass with 2 GPUs (pp=2) via manual P2P.
128133
129-
Validates that gradients flow correctly through all PP stages.
134+
Validates that gradients flow correctly through all PP stages
135+
using manual gradient passing between stages (1F1B only).
130136
"""
131137
if current_platform.device_count() < 2:
132138
pytest.skip("This test requires 2 GPUs")
133139

134140
_run_pp_test_with_torchrun(
135141
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
136142
n_gpus=2,
137-
extra_args=["--test_type=backward", "--pp_size=2"],
143+
extra_args=["--test_type=backward_p2p", "--pp_size=2"],
138144
)
139145

140146

@@ -161,6 +167,55 @@ def test_pp_gradient_correctness_2gpu():
161167
)
162168

163169

170+
# =============================================================================
171+
# Schedule API Tests (2 GPU)
172+
# =============================================================================
173+
174+
175+
@pytest.mark.multi_gpu
176+
@pytest.mark.slow
177+
def test_pp_zbv_forward_2gpu():
178+
"""Test ZBVZeroBubble forward pass with 2 GPUs (pp=2) via schedule API.
179+
180+
Validates that PP model with ZBVZeroBubble schedule produces correct output
181+
using schedule.eval() API with V-style stage assignment.
182+
"""
183+
if current_platform.device_count() < 2:
184+
pytest.skip("This test requires 2 GPUs")
185+
186+
_run_pp_test_with_torchrun(
187+
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
188+
n_gpus=2,
189+
extra_args=[
190+
"--test_type=forward_schedule",
191+
"--pp_size=2",
192+
"--pp_schedule=ZBVZeroBubble",
193+
],
194+
)
195+
196+
197+
@pytest.mark.multi_gpu
198+
@pytest.mark.slow
199+
def test_pp_zbv_backward_2gpu():
200+
"""Test ZBVZeroBubble backward pass with 2 GPUs (pp=2) via schedule API.
201+
202+
Validates that gradients flow correctly through all PP stages
203+
using schedule.step() API with ZBVZeroBubble V-style stage assignment.
204+
"""
205+
if current_platform.device_count() < 2:
206+
pytest.skip("This test requires 2 GPUs")
207+
208+
_run_pp_test_with_torchrun(
209+
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
210+
n_gpus=2,
211+
extra_args=[
212+
"--test_type=backward_schedule",
213+
"--pp_size=2",
214+
"--pp_schedule=ZBVZeroBubble",
215+
],
216+
)
217+
218+
164219
# =============================================================================
165220
# 4 GPU Tests (Extended PP tests)
166221
# =============================================================================
@@ -169,34 +224,36 @@ def test_pp_gradient_correctness_2gpu():
169224
@pytest.mark.multi_gpu
170225
@pytest.mark.slow
171226
def test_pp_forward_4gpu():
172-
"""Test PP forward pass with 4 GPUs (pp=4).
227+
"""Test PP forward pass with 4 GPUs (pp=4) via manual P2P.
173228
174-
Validates PP with more stages (4 stages instead of 2).
229+
Validates PP with more stages (4 stages instead of 2) using
230+
manual activation passing (1F1B only).
175231
"""
176232
if current_platform.device_count() < 4:
177233
pytest.skip("This test requires 4 GPUs")
178234

179235
_run_pp_test_with_torchrun(
180236
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
181237
n_gpus=4,
182-
extra_args=["--test_type=forward", "--pp_size=4"],
238+
extra_args=["--test_type=forward_p2p", "--pp_size=4"],
183239
)
184240

185241

186242
@pytest.mark.multi_gpu
187243
@pytest.mark.slow
188244
def test_pp_backward_4gpu():
189-
"""Test PP backward pass with 4 GPUs (pp=4).
245+
"""Test PP backward pass with 4 GPUs (pp=4) via manual P2P.
190246
191-
Validates gradient flow with more stages.
247+
Validates gradient flow with more stages using manual gradient
248+
passing (1F1B only).
192249
"""
193250
if current_platform.device_count() < 4:
194251
pytest.skip("This test requires 4 GPUs")
195252

196253
_run_pp_test_with_torchrun(
197254
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
198255
n_gpus=4,
199-
extra_args=["--test_type=backward", "--pp_size=4"],
256+
extra_args=["--test_type=backward_p2p", "--pp_size=4"],
200257
)
201258

202259

0 commit comments

Comments (0)