-
Notifications
You must be signed in to change notification settings - Fork 485
Expand file tree
/
Copy pathtest_distributed_pp.py
More file actions
370 lines (286 loc) · 11.3 KB
/
test_distributed_pp.py
File metadata and controls
370 lines (286 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
"""Pipeline Parallelism (PP) distributed tests for Archon Engine.
These tests require multiple GPUs and use torchrun for distributed execution.
Run tests:
pytest areal/tests/experimental/archon/test_distributed_pp.py -v -m multi_gpu
Test configuration:
2 GPU Tests (Core PP - manual P2P):
- test_pp_forward_2gpu: PP=2, manual activation passing (1F1B)
- test_pp_backward_2gpu: PP=2, manual gradient passing (1F1B)
- test_pp_gradient_correctness_2gpu: PP=2, tests PP gradients match non-PP
4 GPU Tests (Extended PP - manual P2P):
- test_pp_forward_4gpu: PP=4, manual activation passing (1F1B)
- test_pp_backward_4gpu: PP=4, manual gradient passing (1F1B)
Schedule API Tests (2 GPU):
- test_pp_zbv_forward_2gpu: PP=2, schedule.eval() with ZBVZeroBubble
- test_pp_zbv_backward_2gpu: PP=2, schedule.step() with ZBVZeroBubble
PP Combination Tests (4 GPU):
- test_pp_tp_forward_4gpu: PP=2, TP=2, tests PP+TP combination
- test_pp_dp_forward_4gpu: PP=2, DP=2, tests PP+DP combination
- test_pp_ep_forward_4gpu: PP=2, EP=2, tests PP+EP combination (MoE model)
PP Checkpoint Tests (2 GPU):
- test_pp_dcp_checkpoint_2gpu: PP=2, tests DCP save/load
- test_pp_dcp_with_optim_2gpu: PP=2, tests DCP with optimizer state
- test_pp_forward_match_2gpu: PP=2, tests forward match after checkpoint
"""
import os
import subprocess
import tempfile
import pytest
import torch
from areal.infra.platforms import current_platform
from areal.utils.network import find_free_ports
# Module-wide marker: every test in this file is skipped when torch reports
# no usable CUDA device.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def _run_pp_test_with_torchrun(
    script: str,
    n_gpus: int,
    extra_args: list[str] | None = None,
    timeout: int = 300,
):
    """Launch ``script`` under torchrun and fail the calling test unless it passes.

    The script receives ``--output=<tmpfile>`` and must write the literal
    string ``Passed`` into that file on success. The temporary result file is
    always removed afterwards.

    Args:
        script: Path to the torchrun entry-point script.
        n_gpus: Number of processes to launch (``--nproc_per_node``).
        extra_args: Additional command line arguments appended after ``--output``.
        timeout: Timeout in seconds for the whole torchrun invocation.

    Raises:
        pytest.fail.Exception: If torchrun exits non-zero, times out, or the
            result file does not contain ``Passed``.
    """

    def _as_text(stream) -> str:
        # TimeoutExpired carries captured output as bytes even with text=True,
        # while CalledProcessError carries str; either may also be None.
        if isinstance(stream, bytes):
            return stream.decode(errors="replace")
        return stream or ""

    port = find_free_ports(1)[0]
    # delete=False so the child processes can reopen the file by name after this
    # handle is closed; it is cleaned up in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
        out_file = f.name
    cmd = [
        "torchrun",
        f"--nproc_per_node={n_gpus}",
        "--nnodes=1",
        "--master_addr=localhost",
        f"--master_port={port}",
        script,
        f"--output={out_file}",
    ]
    if extra_args:
        cmd.extend(extra_args)
    try:
        try:
            result = subprocess.run(
                cmd,
                check=True,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.CalledProcessError as e:
            pytest.fail(
                f"Test failed with error: {_as_text(e.stderr)}\n"
                f"stdout: {_as_text(e.stdout)}"
            )
        except subprocess.TimeoutExpired as e:
            pytest.fail(
                f"Test timed out after {timeout} seconds\n"
                f"stdout: {_as_text(e.stdout)}\nstderr: {_as_text(e.stderr)}"
            )
        # The script reports its verdict through the result file.
        with open(out_file) as f:
            test_result = f.read().strip()
        if test_result != "Passed":
            pytest.fail(
                f"Test returned '{test_result}'. "
                f"stdout: {result.stdout}\nstderr: {result.stderr}"
            )
    finally:
        # Avoid leaking one temp file per test invocation.
        try:
            os.remove(out_file)
        except FileNotFoundError:
            pass
# =============================================================================
# 2 GPU Tests (Core PP tests)
# =============================================================================
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_forward_2gpu():
    """Forward pass with pp=2 using manual P2P activation passing (1F1B only).

    The pipelined model's output is compared against a golden (non-PP) model.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    runner_args = ["--test_type=forward_p2p", "--pp_size=2"]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_backward_2gpu():
    """Backward pass with pp=2 using manual P2P gradient passing (1F1B only).

    Checks that gradients propagate correctly through every PP stage.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    runner_args = ["--test_type=backward_p2p", "--pp_size=2"]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
@pytest.mark.multi_gpu
def test_pp_gradient_correctness_2gpu():
    """Gradient correctness with pp=2.

    Compares gradients from the PP step() API with those obtained from
    manual forward/backward passes without pipeline parallelism.

    Model layout used by the runner script:
    - FirstStageModel: embedding + 2 transformer blocks
    - LastStageModel: 2 transformer blocks + output head
    - 2 microbatches with different packed sequence lengths
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    verify_script = "areal/tests/experimental/archon/torchrun/run_pp_gradient_verify.py"
    _run_pp_test_with_torchrun(
        verify_script,
        n_gpus=required_gpus,
        extra_args=["--n_microbatches=2", "--seq_len=64"],
        timeout=120,
    )
# =============================================================================
# Schedule API Tests (2 GPU)
# =============================================================================
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_zbv_forward_2gpu():
    """Forward pass with pp=2 through the ZBVZeroBubble schedule API.

    Uses schedule.eval() with V-style stage assignment and checks the output.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    schedule_name = "ZBVZeroBubble"
    runner_args = [
        "--test_type=forward_schedule",
        "--pp_size=2",
        f"--pp_schedule={schedule_name}",
    ]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_zbv_backward_2gpu():
    """Backward pass with pp=2 through the ZBVZeroBubble schedule API.

    Uses schedule.step() with V-style stage assignment and checks that
    gradients reach every PP stage.
    """
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    schedule_name = "ZBVZeroBubble"
    runner_args = [
        "--test_type=backward_schedule",
        "--pp_size=2",
        f"--pp_schedule={schedule_name}",
    ]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
# =============================================================================
# 4 GPU Tests (Extended PP tests)
# =============================================================================
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_forward_4gpu():
    """Forward pass with pp=4 using manual P2P activation passing (1F1B only).

    Same check as the 2-GPU variant, but across four pipeline stages.
    """
    required_gpus = 4
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 4 GPUs")
    runner_args = ["--test_type=forward_p2p", "--pp_size=4"]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_backward_4gpu():
    """Backward pass with pp=4 using manual P2P gradient passing (1F1B only).

    Checks gradient flow across four pipeline stages.
    """
    required_gpus = 4
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 4 GPUs")
    runner_args = ["--test_type=backward_p2p", "--pp_size=4"]
    _run_pp_test_with_torchrun(
        "areal/tests/experimental/archon/torchrun/run_pp_tests.py",
        n_gpus=required_gpus,
        extra_args=runner_args,
    )
# =============================================================================
# PP Combination Tests (PP+TP, PP+DP, PP+EP)
# =============================================================================
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_tp_forward_4gpu():
    """PP combined with Tensor Parallelism on 4 GPUs (pp=2, tp=2)."""
    required_gpus = 4
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 4 GPUs")
    combo_script = "areal/tests/experimental/archon/torchrun/run_pp_combinations.py"
    _run_pp_test_with_torchrun(
        combo_script,
        n_gpus=required_gpus,
        extra_args=["--test_type=pp_tp"],
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_dp_forward_4gpu():
    """PP combined with Data Parallelism on 4 GPUs (pp=2, dp_shard=2)."""
    required_gpus = 4
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 4 GPUs")
    combo_script = "areal/tests/experimental/archon/torchrun/run_pp_combinations.py"
    _run_pp_test_with_torchrun(
        combo_script,
        n_gpus=required_gpus,
        extra_args=["--test_type=pp_dp"],
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_ep_forward_4gpu():
    """PP combined with Expert Parallelism on 4 GPUs (pp=2, ep=2).

    Expert parallelism needs experts to shard, so the runner uses a MoE model.
    """
    required_gpus = 4
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 4 GPUs")
    combo_script = "areal/tests/experimental/archon/torchrun/run_pp_combinations.py"
    _run_pp_test_with_torchrun(
        combo_script,
        n_gpus=required_gpus,
        extra_args=["--test_type=pp_ep"],
    )
# =============================================================================
# PP Checkpoint Tests
# =============================================================================
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_dcp_checkpoint_2gpu():
    """DCP-format save/load of PP model weights on 2 GPUs (pp=2)."""
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    ckpt_script = "areal/tests/experimental/archon/torchrun/run_checkpoint_tests.py"
    runner_args = ["--test_type=pp_dcp_checkpoint", "--pp_size=2"]
    _run_pp_test_with_torchrun(ckpt_script, n_gpus=required_gpus, extra_args=runner_args)
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_dcp_with_optim_2gpu():
    """DCP-format save/load of PP model plus optimizer state on 2 GPUs (pp=2)."""
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    ckpt_script = "areal/tests/experimental/archon/torchrun/run_checkpoint_tests.py"
    runner_args = ["--test_type=pp_dcp_with_optim", "--pp_size=2"]
    _run_pp_test_with_torchrun(ckpt_script, n_gpus=required_gpus, extra_args=runner_args)
@pytest.mark.multi_gpu
@pytest.mark.slow
def test_pp_forward_match_2gpu():
    """Forward output must be unchanged by a PP checkpoint save/load (pp=2)."""
    required_gpus = 2
    if current_platform.device_count() < required_gpus:
        pytest.skip("This test requires 2 GPUs")
    ckpt_script = "areal/tests/experimental/archon/torchrun/run_checkpoint_tests.py"
    runner_args = ["--test_type=pp_forward_match", "--pp_size=2"]
    _run_pp_test_with_torchrun(ckpt_script, n_gpus=required_gpus, extra_args=runner_args)