Skip to content

Commit a39ea69

Browse files
committed
feat(archon): add ZBVZeroBubble pipeline schedule support
Add V-style stage assignment for ZBVZeroBubble, where rank 0 holds both the first and last stages for near-zero pipeline bubbles. ZBV's split backward (I/W separation) is incompatible with torch.compile, op-level selective AC, and memory_budget AC. These are auto-detected and handled: compile is disabled, and incompatible AC modes fall back to full AC.

Key changes:
- Add "ZBVZeroBubble" to pp_schedule config choices
- Add V-style stage assignment in _get_stage_indices()
- Move _pp_last_stage_rank setup into _apply_pipeline_parallelism()
- Auto-disable torch.compile and incompatible AC modes for ZBV
- Add ZBV FQN generation tests and distributed test entries
- Add torchrun ZBV forward/backward test support
1 parent 0d11df6 commit a39ea69

File tree

7 files changed

+453
-196
lines changed

7 files changed

+453
-196
lines changed

areal/api/cli_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ class ArchonEngineConfig:
453453
default="Interleaved1F1B",
454454
metadata={
455455
"help": "Pipeline parallel schedule type.",
456-
"choices": ["1F1B", "Interleaved1F1B"],
456+
"choices": ["1F1B", "Interleaved1F1B", "ZBVZeroBubble"],
457457
},
458458
)
459459
# NOTE: The following three PP layer distribution parameters are advanced options
@@ -466,7 +466,7 @@ class ArchonEngineConfig:
466466
"help": "Number of transformer layers per (virtual) pipeline stage. "
467467
"If set, num_virtual_stages is calculated from num_layers. "
468468
"If None, stages are inferred from schedule type "
469-
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B).",
469+
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B/ZBVZeroBubble).",
470470
},
471471
)
472472
pp_first_stage_less_layers: int = field(

areal/experimental/engine/archon_engine.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import torch
1313
import torch.distributed as dist
1414
from torch import nn
15+
from torch.distributed.pipelining.schedules import (
16+
ScheduleZBVZeroBubble,
17+
get_schedule_class,
18+
)
1519
from transformers import (
1620
AutoConfig,
1721
PretrainedConfig,
@@ -211,9 +215,9 @@ def create_process_group(
211215

212216
# Pipeline parallel rank
213217
if self.parallel_dims.pp_enabled:
214-
pp_group = self.parallel_dims.get_group("pp")
215218
self._pp_rank = self.parallel_dims.get_mesh("pp").get_local_rank()
216-
self._pp_last_stage_rank = dist.get_process_group_ranks(pp_group)[-1]
219+
# Set in _apply_pipeline_parallelism() after pipeline setup
220+
self._pp_last_stage_rank = None
217221
else:
218222
self._pp_rank = 0
219223
self._pp_last_stage_rank = None
@@ -269,6 +273,36 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
269273
ac_config = self._build_ac_config()
270274
enable_compile = self.config.archon.enable_compile
271275

276+
# ZBVZeroBubble splits backward into I (input grad) and W (weight grad)
277+
# steps. This is incompatible with:
278+
# 1. torch.compile — its donated buffer optimization assumes a single
279+
# backward pass (retain_graph=False).
280+
# 2. Op-level selective AC — its per-op cache (storage.pop) is consumed
281+
# by the I step, leaving nothing for the W step recompute.
282+
# 3. memory_budget AC — it depends on torch.compile.
283+
# Full AC / layer-level selective AC use standard checkpoint_wrapper
284+
# whose gid-based recompute supports multiple backward passes.
285+
if self.config.archon.pp_schedule == "ZBVZeroBubble":
286+
if enable_compile:
287+
self.logger.warning(
288+
"ZBVZeroBubble is incompatible with torch.compile. "
289+
"Disabling torch.compile."
290+
)
291+
enable_compile = False
292+
293+
if ac_config is not None and (
294+
(
295+
ac_config.mode == "selective"
296+
and ac_config.selective_ac_option == "op"
297+
)
298+
or ac_config.mode == "memory_budget"
299+
):
300+
self.logger.warning(
301+
f"ZBVZeroBubble is incompatible with {ac_config.mode} AC. "
302+
"Falling back to full AC."
303+
)
304+
ac_config.mode = "full"
305+
272306
# Force pad_to_maximum when compile is enabled to avoid dynamic shape issues
273307
if enable_compile and not self.config.pad_to_maximum:
274308
self.logger.info(
@@ -790,6 +824,17 @@ def _apply_pipeline_parallelism(
790824
# Delete original model to free memory
791825
del self.model
792826

827+
# Determine which rank holds the last pipeline stage
828+
pp_group = self.parallel_dims.get_group("pp")
829+
pp_ranks = dist.get_process_group_ranks(pp_group)
830+
schedule_class = get_schedule_class(self.config.archon.pp_schedule)
831+
if schedule_class is ScheduleZBVZeroBubble:
832+
# V-style: rank 0 holds stages (0, num_stages-1)
833+
self._pp_last_stage_rank = pp_ranks[0]
834+
else:
835+
# Loop-style: last rank has last stage
836+
self._pp_last_stage_rank = pp_ranks[-1]
837+
793838
self.logger.info(
794839
f"PP enabled: has_first={self.pp_has_first_stage}, "
795840
f"has_last={self.pp_has_last_stage}"

areal/experimental/models/archon/pipeline_parallel.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from torch.distributed.pipelining.schedules import (
1515
PipelineScheduleMulti,
1616
PipelineScheduleSingle,
17+
ScheduleDualPipeV,
18+
ScheduleZBVZeroBubble,
1719
get_schedule_class,
1820
)
1921

@@ -210,7 +212,7 @@ def pipeline_module_split(
210212
Args:
211213
whole_model: The complete model to split
212214
pp_mesh: Pipeline parallel device mesh
213-
pp_schedule: Schedule type ("1F1B" or "Interleaved1F1B")
215+
pp_schedule: Schedule type ("1F1B", "Interleaved1F1B", or "ZBVZeroBubble")
214216
device: Target device for stages
215217
module_names_per_stage: Module FQNs for each stage
216218
@@ -297,29 +299,31 @@ def _get_stage_indices() -> tuple[int, ...]:
297299
Examples (pp_degree=4, num_stages=8):
298300
1F1B: Rank 0->(0,), Rank 1->(1,), ...
299301
Interleaved1F1B: Rank 0->(0,4), Rank 1->(1,5), Rank 2->(2,6), Rank 3->(3,7)
302+
ZBVZeroBubble: Rank 0->(0,7), Rank 1->(1,6), Rank 2->(2,5), Rank 3->(3,4)
300303
"""
301304
if num_stages % pp_degree != 0:
302305
raise ValueError(
303-
f"num_stages ({num_stages}) must be divisible by pp_degree ({pp_degree})"
306+
f"num_stages ({num_stages}) must be evenly divisible by "
307+
f"pp_degree ({pp_degree})"
304308
)
305309
stages_per_rank = num_stages // pp_degree
306310

307-
if pp_schedule == "1F1B":
308-
if stages_per_rank != 1:
309-
raise ValueError(
310-
f"1F1B schedule requires exactly 1 stage per rank, "
311-
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
312-
)
313-
return (pp_rank,)
314-
elif pp_schedule == "Interleaved1F1B":
315-
if stages_per_rank < 2:
311+
schedule_class = get_schedule_class(pp_schedule)
312+
v_style_schedules = (ScheduleZBVZeroBubble, ScheduleDualPipeV)
313+
style = "v" if schedule_class in v_style_schedules else "loop"
314+
315+
if style == "v":
316+
if stages_per_rank != 2:
316317
raise ValueError(
317-
f"Interleaved1F1B schedule requires >= 2 stages per rank, "
318-
f"got {stages_per_rank} ({num_stages} stages / {pp_degree} ranks)"
318+
f"V-style schedules require exactly 2 stages per rank, "
319+
f"got {stages_per_rank}"
319320
)
320-
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
321+
stage_v_pairs = list(
322+
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
323+
)
324+
return stage_v_pairs[pp_rank]
321325
else:
322-
raise ValueError(f"Unknown pp_schedule: {pp_schedule}")
326+
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
323327

324328
stages: list[PipelineStage] = []
325329
model_parts: list[nn.Module] = []
@@ -351,7 +355,7 @@ def pipeline_llm(
351355
352356
Workflow:
353357
1. Generate module names for each virtual stage
354-
2. Split model into stages (multiple per rank for Interleaved1F1B)
358+
2. Split model into stages (multiple per rank for Interleaved1F1B/ZBVZeroBubble)
355359
3. Apply parallelization (TP, FSDP) to each model part
356360
357361
Args:
@@ -364,7 +368,7 @@ def pipeline_llm(
364368
365369
Returns:
366370
Tuple of:
367-
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B)
371+
- stages: List of PipelineStage (1 for 1F1B, 2+ for Interleaved1F1B/ZBVZeroBubble)
368372
- model_parts: List of model parts
369373
- has_first_stage: Whether this rank has the first stage
370374
- has_last_stage: Whether this rank has the last stage
@@ -429,6 +433,14 @@ def pipeline_llm(
429433
f"but got {stages_per_rank} (from layers_per_stage={layers_per_stage}). "
430434
f"Use 1F1B schedule for single stage per rank."
431435
)
436+
# ZBVZeroBubble requires exactly 2 stages per rank
437+
if schedule_class is ScheduleZBVZeroBubble and stages_per_rank != 2:
438+
raise ValueError(
439+
f"ZBVZeroBubble requires exactly 2 stages per rank, "
440+
f"but got {stages_per_rank}. "
441+
f"Set pp_layers_per_stage to achieve 2 stages per rank, "
442+
f"or let it default (None)."
443+
)
432444
else:
433445
# Default: 1 for single-stage schedules, 2 for multi-stage schedules
434446
stages_per_rank = 1 if is_single_stage_schedule else 2

areal/tests/experimental/archon/test_distributed_pp.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
- test_pp_forward_4gpu: PP=4, tests forward with more stages
1616
- test_pp_backward_4gpu: PP=4, tests backward with more stages
1717
18+
ZBV Tests (2 GPU):
19+
- test_pp_zbv_forward_2gpu: PP=2, ZBVZeroBubble schedule forward
20+
- test_pp_zbv_backward_2gpu: PP=2, ZBVZeroBubble schedule backward
21+
1822
PP Combination Tests (4 GPU):
1923
- test_pp_tp_forward_4gpu: PP=2, TP=2, tests PP+TP combination
2024
- test_pp_dp_forward_4gpu: PP=2, DP=2, tests PP+DP combination
@@ -161,6 +165,55 @@ def test_pp_gradient_correctness_2gpu():
161165
)
162166

163167

168+
# =============================================================================
169+
# ZBV Tests (2 GPU)
170+
# =============================================================================
171+
172+
173+
@pytest.mark.multi_gpu
174+
@pytest.mark.slow
175+
def test_pp_zbv_forward_2gpu():
176+
"""Test ZBVZeroBubble forward pass with 2 GPUs (pp=2).
177+
178+
Validates that PP model with ZBVZeroBubble schedule produces correct output.
179+
Uses V-style stage assignment where rank 0 holds first and last stages.
180+
"""
181+
if current_platform.device_count() < 2:
182+
pytest.skip("This test requires 2 GPUs")
183+
184+
_run_pp_test_with_torchrun(
185+
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
186+
n_gpus=2,
187+
extra_args=[
188+
"--test_type=forward",
189+
"--pp_size=2",
190+
"--pp_schedule=ZBVZeroBubble",
191+
],
192+
)
193+
194+
195+
@pytest.mark.multi_gpu
196+
@pytest.mark.slow
197+
def test_pp_zbv_backward_2gpu():
198+
"""Test ZBVZeroBubble backward pass with 2 GPUs (pp=2).
199+
200+
Validates that gradients flow correctly through all PP stages
201+
using ZBVZeroBubble V-style stage assignment.
202+
"""
203+
if current_platform.device_count() < 2:
204+
pytest.skip("This test requires 2 GPUs")
205+
206+
_run_pp_test_with_torchrun(
207+
"areal/tests/experimental/archon/torchrun/run_pp_tests.py",
208+
n_gpus=2,
209+
extra_args=[
210+
"--test_type=backward",
211+
"--pp_size=2",
212+
"--pp_schedule=ZBVZeroBubble",
213+
],
214+
)
215+
216+
164217
# =============================================================================
165218
# 4 GPU Tests (Extended PP tests)
166219
# =============================================================================

areal/tests/experimental/archon/test_pipeline_parallel.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,30 @@ def test_exact_distribution(self):
221221
assert result[2] == ["layers.3", "layers.4"]
222222
# Stage 3: 1 layer + norm + output (1) = 2
223223
assert result[3] == ["layers.5", "norm", "output"]
224+
225+
226+
class TestZBVFqnGeneration:
227+
"""Test FQN generation for ZBV pipeline configurations."""
228+
229+
def test_zbv_fqn_generation(self):
230+
"""Verify FQN distribution for a typical ZBV config (pp_degree=2, 8 layers)."""
231+
result = generate_llm_fqn_per_model_part(num_stages=4, num_layers=8)
232+
assert len(result) == 4
233+
234+
# Rank 0 gets stages (0, 3), rank 1 gets stages (1, 2)
235+
rank0_modules = result[0] + result[3]
236+
rank1_modules = result[1] + result[2]
237+
238+
# Rank 0 has first and last stages
239+
assert "tok_embeddings" in rank0_modules
240+
assert "norm" in rank0_modules
241+
assert "output" in rank0_modules
242+
243+
# Rank 1 has only middle layers (no embeddings or output head)
244+
assert all(m.startswith("layers.") for m in rank1_modules)
245+
246+
# All layers covered exactly once
247+
all_layers = []
248+
for stage in result:
249+
all_layers.extend([m for m in stage if m.startswith("layers.")])
250+
assert all_layers == [f"layers.{i}" for i in range(8)]

0 commit comments

Comments (0)