-
Notifications
You must be signed in to change notification settings - Fork 486
Expand file tree
/
Copy pathcli_args.py
More file actions
2058 lines (1850 loc) · 69.6 KB
/
cli_args.py
File metadata and controls
2058 lines (1850 loc) · 69.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import argparse
import json
import os
from dataclasses import MISSING as dataclass_missing
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar
import uvloop
import yaml
from hydra import compose as hydra_compose
from hydra import initialize as hydra_init
from hydra.core.global_hydra import GlobalHydra
from omegaconf import MISSING, DictConfig, OmegaConf
from areal.utils import logging, name_resolve, pkg_version
from areal.utils.constants import (
PROX_LOGP_METHOD_RECOMPUTE,
PROX_LOGP_METHODS_ALL,
)
from areal.utils.pkg_version import is_version_less
if TYPE_CHECKING:
from transformers import PreTrainedTokenizerFast
# NOTE(review): uvloop.install() runs as a side effect of importing this module,
# presumably so that every asyncio loop created afterwards in the process uses
# uvloop; confirm this process-wide side effect is intended for all importers.
uvloop.install()
# Module logger. `logging` here is `areal.utils.logging` (see the imports above),
# not the standard-library module.
logger = logging.getLogger("CLIArgs")
@dataclass
class NormConfig:
"""Configuration for reward/advantage normalization."""
mean_level: str | None = field(
default="batch",
metadata={
"help": "Mean level for normalization. None for no mean normalization.",
"choices": ["batch", "group", None],
},
)
mean_leave1out: bool = field(
default=False,
metadata={"help": "Whether to use leave-one-out average."},
)
std_level: str | None = field(
default="batch",
metadata={
"help": "Standard deviation level for normalization. None for no std normalization.",
"choices": ["batch", "group", None],
},
)
std_unbiased: bool = field(
default=True,
metadata={
"help": "Whether to use unbiased standard deviation computation. Defaults to True (changed from False in v0.3.4)."
},
)
eps: float = field(
default=1e-5,
metadata={
"help": "The eps when dividing by standard deviation to avoid numerical issues."
},
)
group_size: int = field(
default=1, metadata={"help": "Group size for group-level normalization"}
)
@dataclass
class MicroBatchSpec:
"""Specification for splitting micro-batches during training."""
n_mbs: int | None = field(
default=1,
metadata={
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
},
)
granularity: int = field(
default=1,
metadata={
"help": "Granularity of each micro-batch. Adjacent sequences are grouped by this size when dividing microbatches.",
},
)
max_tokens_per_mb: int | None = field(
default=None,
metadata={
"help": "Maximum tokens per micro-batch for each forward pass. When set, n_mbs becomes the minimum number of micro-batches.",
},
)
n_mbs_divisor: int = field(
default=1,
metadata={
"help": "Divisor for the number of micro-batches. The final number of micro-batches will be adjusted to be divisible by this value.",
},
)
@classmethod
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
"""Create new spec with updated fields while maintaining Omegaconf compatibility."""
fields = dict(
n_mbs=mb_spec.n_mbs,
granularity=mb_spec.granularity,
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
n_mbs_divisor=mb_spec.n_mbs_divisor,
)
fields.update(kwargs)
return cls(**fields)
@dataclass
class GenerationHyperparameters:
"""Controls text generation behavior for rollout."""
n_samples: int = field(
default=1, metadata={"help": "Number of sequences to generate per prompt."}
)
max_new_tokens: int = field(
default=16384, metadata={"help": "Maximum number of tokens to generate."}
)
min_new_tokens: int = field(
default=0, metadata={"help": "Minimum number of tokens to generate."}
)
max_tokens: int = field(
default=32768,
metadata={
"help": "Maximum number of tokens including prompt and generated tokens."
},
)
greedy: bool = field(
default=False,
metadata={"help": "Whether to use greedy decoding (max probability)."},
)
top_p: float = field(
default=1.0,
metadata={"help": "Nucleus sampling probability threshold (0.0, 1.0]."},
)
top_k: int = field(
default=int(1e8),
metadata={"help": "Number of highest probability tokens to consider."},
)
temperature: float = field(
default=1.0,
metadata={"help": "Sampling temperature. Higher values increase diversity."},
)
stop_token_ids: list[int] = field(
default_factory=list,
metadata={"help": "Stop generation when encountering these token IDs."},
)
ignore_eos: bool = field(
default=False,
metadata={"help": "Do not stop generation when EOS is encountered."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Skip special tokens when decoding/displaying outputs."},
)
stop: list[str] | None = field(
default=None,
metadata={
"help": "One or multiple stop words. Generation will stop if one of these words is sampled."
},
)
frequency_penalty: float = field(
default=0.0,
metadata={
"help": (
"Penalizes tokens based on their frequency in generation so far. "
"Must be between -2 and 2 where negative numbers encourage repetition."
)
},
)
lora_name: str = field(
default="",
metadata={"help": "Lora name to be used for this generation."},
)
use_beam_search: bool = field(
default=False,
metadata={
"help": "Enable beam search in the vLLM engine. When enabled, sampling parameters like temperature, top-p, and top-k are auto ignored."
},
)
# NOTE: to add new parameters, please correctly handle them in the `to_openai_args_dict` method.
def new(self, **kwargs):
args = asdict(self)
args.update(kwargs)
return GenerationHyperparameters(**args)
def new_with_stop_and_pad_token_ids(self, tokenizer: "PreTrainedTokenizerFast"):
"""Create a new generation hyperparameters with stop and pad token ids added."""
new_stop_token_ids = self.stop_token_ids.copy()
if tokenizer.pad_token_id not in new_stop_token_ids:
new_stop_token_ids.append(tokenizer.pad_token_id)
if tokenizer.eos_token_id not in new_stop_token_ids:
new_stop_token_ids.append(tokenizer.eos_token_id)
return self.new(stop_token_ids=new_stop_token_ids)
def to_openai_completions_args_dict(
self, exclude_args: list[str] | None = None
) -> dict[str, Any]:
return self.to_openai_args_dict(
exclude_args=exclude_args, api_format="completions"
)
def to_openai_responses_args_dict(
self, exclude_args: list[str] | None = None
) -> dict[str, Any]:
return self.to_openai_args_dict(
exclude_args=exclude_args, api_format="responses"
)
def to_openai_agents_model_settings_dict(
self, exclude_args: list[str] | None = None
) -> dict[str, Any]:
return self.to_openai_args_dict(
exclude_args=exclude_args, api_format="openai-agents"
)
_OPENAI_UNSUPPORTED_ARGS: ClassVar[set[str]] = {
"min_new_tokens", # Not supported by OpenAI
"greedy", # Not directly supported by OpenAI
"top_k", # Not supported by OpenAI
"stop_token_ids", # Not supported by OpenAI
"ignore_eos", # Not supported by OpenAI
"skip_special_tokens", # Not supported by OpenAI
"lora_name", # Not supported by OpenAI
"use_beam_search", # Not supported by OpenAI
"max_tokens", # deprecated by "completions", not used in "responses", should be `max_new_tokens` in "openai-agents"
}
def to_openai_args_dict(
self, exclude_args: list[str] | None = None, api_format: str = "completions"
) -> dict[str, Any]:
"""Convert the generation hyperparameters to a dictionary of arguments for OpenAI client."""
final_exclude_args = set(exclude_args) if exclude_args is not None else set()
final_exclude_args.update(self._OPENAI_UNSUPPORTED_ARGS)
# TODO: move the excluded args into extra body, so they can be passed through the client request
mapping = {"n_samples": "n"}
if api_format == "completions":
mapping["max_new_tokens"] = "max_completion_tokens"
elif api_format == "responses":
mapping["max_new_tokens"] = "max_output_tokens"
elif api_format == "openai-agents":
# NOTE: max_tokens in openai-agents means `max_new_tokens` in sglang/vllm. This is not a bug
mapping["max_new_tokens"] = "max_tokens"
else:
raise ValueError(f"Unsupported API format: {api_format}")
res = {}
for k, v in asdict(self).items():
if k in final_exclude_args:
should_warn = False
current_value = getattr(self, k)
f = next(_field for _field in fields(self) if _field.name == k)
# Check if equal to the default value
if f.default is not dataclass_missing:
if current_value != f.default:
should_warn = True
elif f.default_factory is not dataclass_missing:
if current_value != f.default_factory():
should_warn = True
if should_warn:
logger.warning(
f"Unsupported arg for openai format: `{k}` with value {current_value}"
)
continue
key = mapping.get(k, k)
if key in res:
logger.warning(f"Overriding key: {key} from {k} with value: {v}")
res[key] = v
return res
# Train Engine Configs
@dataclass
class OptimizerConfig:
    """Optimizer, learning-rate schedule, loss-scaling and clipping settings."""

    # ---- Optimizer choice and core hyperparameters ----
    type: str = field(
        default="adam",
        metadata=dict(
            help="Optimizer type. Adam_bf16 currently only supported FSDP Engine.",
            choices=["adam", "sgd", "adam_bf16"],
        ),
    )
    lr: float = field(default=1e-3, metadata=dict(help="Learning rate"))
    weight_decay: float = field(default=0.01, metadata=dict(help="Weight decay"))
    # Adam-family moments; ignored by SGD according to the help text.
    beta1: float = field(
        default=0.9,
        metadata=dict(
            help="Adam beta1 parameter. Only effective when optimizer_type is adam/adam_bf16"
        ),
    )
    beta2: float = field(
        default=0.999,
        metadata=dict(
            help="Adam beta2 parameter. Only effective when optimizer_type is adam/adam_bf16"
        ),
    )
    eps: float = field(
        default=1e-8,
        metadata=dict(
            help="Adam epsilon parameter. Only effective when optimizer_type is adam/adam_bf16"
        ),
    )
    # ---- Learning-rate schedule ----
    min_lr_ratio: float = field(
        default=0.0,
        metadata=dict(help="Minimum learning rate ratio after annealing"),
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata=dict(
            help="Learning rate scheduler type",
            choices=["linear", "cosine", "constant"],
        ),
    )
    warmup_steps_proportion: float = field(
        default=0.001,
        metadata=dict(help="Proportion of training steps for warmup"),
    )
    offload: bool = field(
        default=False, metadata=dict(help="Enable optimizer state offloading")
    )
    # ---- Loss scaling ----
    initial_loss_scale: float = field(
        default=2**32, metadata=dict(help="Initial loss scaling factor")
    )
    min_loss_scale: float = field(
        default=1.0, metadata=dict(help="Minimum loss scaling factor")
    )
    loss_scale_window: float = field(
        default=5, metadata=dict(help="Window size for loss scaling adjustment")
    )
    hysteresis: int = field(
        default=2, metadata=dict(help="Hysteresis (scaling factor) for loss scaling")
    )
    # ---- Gradient clipping ----
    gradient_clipping: float = field(
        default=1.0, metadata=dict(help="Gradient clipping threshold")
    )
@dataclass
class FSDPWrapPolicy:
"""Policy configuration for FSDP model layer wrapping. None defaults to wrapping transformer decoder layers defined by transformers."""
transformer_layer_cls_to_wrap: list[str] | None = field(
default=None,
metadata={"help": "A list of transformer layer names for FSDP to wrap."},
)
@dataclass
class FSDPEngineConfig:
    """Settings for the FSDP (Fully Sharded Data Parallel) training backend."""

    # Optional layer-wrapping policy for FSDP.
    wrap_policy: FSDPWrapPolicy | None = field(
        default=None,
        metadata=dict(help="FSDP wrap policy, specifying model layers to wrap."),
    )
    offload_params: bool = field(
        default=False,
        metadata=dict(help="Whether to offload FSDP parameters to CPU."),
    )
    # TrainEngineConfig.__post_init__ rejects combining this with init_from_scratch.
    memory_efficient_load: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Enable memory-efficient model loading. When enabled, model weights "
                "are initialized on CPU and only rank 0 loads pretrained weights, which are "
                "then broadcast to all ranks after FSDP sharding. This reduces peak GPU memory "
                "during initialization for large models. Note: For VLMs, rank 0 broadcast is "
                "not used; each rank loads weights independently on CPU."
            )
        ),
    )
@dataclass
class ArchonEngineConfig:
"""Configuration for Archon Engine training backend."""
# Attention backend
attn_type: str = field(
default="varlen",
metadata={
"help": "Attention backend type. Use 'tree' for tree training.",
"choices": ["varlen", "sdpa", "tree"],
},
)
# CPU offloading for FSDP
offload_params: bool = field(
default=False,
metadata={"help": "Whether to offload FSDP parameters to CPU."},
)
# Whether to enable torch.compile
enable_compile: bool = field(
default=True,
metadata={"help": "Enable torch.compile for TransformerBlocks."},
)
# Activation Checkpointing (enabled when gradient_checkpointing=True)
ac_mode: str = field(
default="selective",
metadata={
"help": "Activation checkpointing mode. "
"'memory_budget' requires enable_compile=True.",
"choices": ["none", "full", "selective", "memory_budget"],
},
)
selective_ac_option: str = field(
default="op",
metadata={
"help": "Selective AC option: 'op' for op-level, "
"or integer string (e.g., '2') for every Nth layer."
},
)
ac_memory_budget: float = field(
default=0.5,
metadata={
"help": "Memory budget for 'memory_budget' AC mode. "
"0.0 = minimum memory (max recompute), 1.0 = default behavior (no recompute)."
},
)
ac_preserve_rng_state: bool = field(
default=False,
metadata={
"help": "Preserve RNG state during checkpointing for deterministic output. "
"Enabling this may slow down training."
},
)
ac_debug: bool = field(
default=False,
metadata={
"help": "(Testing only) Capture AC debug information. Will be slower."
},
)
# Pipeline Parallel Schedule
pp_schedule: str = field(
default="Interleaved1F1B",
metadata={
"help": "Pipeline parallel schedule type.",
"choices": ["1F1B", "Interleaved1F1B", "ZBVZeroBubble"],
},
)
# NOTE: The following three PP layer distribution parameters are advanced options
# that most users do not need to configure. The defaults work well for typical cases.
# TODO: Consider simplifying or refactoring these parameters in the future.
# Currently kept for consistency with Megatron's pipeline parallel configuration.
pp_layers_per_stage: int | None = field(
default=None,
metadata={
"help": "Number of transformer layers per (virtual) pipeline stage. "
"If set, num_virtual_stages is calculated from num_layers. "
"If None, stages are inferred from schedule type "
"(1 stage/rank for 1F1B, 2 stages/rank for Interleaved1F1B/ZBVZeroBubble).",
},
)
pp_first_stage_less_layers: int = field(
default=1,
metadata={
"help": "Number of layers to reduce in the first pipeline stage. "
"Accounts for embedding layer overhead.",
},
)
pp_last_stage_less_layers: int = field(
default=1,
metadata={
"help": "Number of layers to reduce in the last pipeline stage. "
"Accounts for output layer overhead.",
},
)
def __post_init__(self):
if self.pp_layers_per_stage is not None and self.pp_layers_per_stage < 1:
raise ValueError(
f"pp_layers_per_stage must be >= 1, got {self.pp_layers_per_stage}"
)
if self.pp_first_stage_less_layers < 0:
raise ValueError(
f"pp_first_stage_less_layers must be >= 0, "
f"got {self.pp_first_stage_less_layers}"
)
if self.pp_last_stage_less_layers < 0:
raise ValueError(
f"pp_last_stage_less_layers must be >= 0, "
f"got {self.pp_last_stage_less_layers}"
)
# These configurations are used by Megatron Bridge to build Megatron models.
@dataclass
class DistributedDataParallelConfig:
"""Configuration for Megatron's DistributedDataParallel.
Refer to Megatron-LM documentation for details.
"""
grad_reduce_in_fp32: bool = True
overlap_grad_reduce: bool = False
overlap_param_gather: bool = False
align_param_gather: bool = False
use_distributed_optimizer: bool = True
check_for_nan_in_grad: bool = False
bucket_size: int | None = None
average_in_collective: bool = False
fp8_param_gather: bool = False
@dataclass
class FP8EngineConfig:
"""Configuration for FP8 (8-bit floating point) training.
This configuration encapsulates all FP8-related parameters and can be reused
across different engines (e.g., Megatron, FSDP). When None in the parent config,
FP8 training is disabled.
"""
mode: str = field(
default="e4m3",
metadata={
"help": "FP8 precision mode. Options: "
"'e4m3' (uniform e4m3), "
"'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients)."
},
)
recipe: str = field(
default="delayed",
metadata={
"help": "FP8 scaling recipe. Options: 'tensorwise', 'delayed', 'mxfp8' (Blackwell only), 'blockwise'."
},
)
param: bool = field(
default=False,
metadata={
"help": "Keep parameters in FP8 precision to save memory. "
"Not all parameters will be converted to fp8; for example, biases will remain unchanged."
},
)
margin: int = field(
default=0,
metadata={"help": "Margin for FP8 scaling factor computation."},
)
amax_history_len: int = field(
default=1,
metadata={
"help": "Length of amax history window for scaling factor computation."
},
)
amax_compute_algo: str = field(
default="most_recent",
metadata={
"help": "Algorithm for choosing amax value. Options: 'max' (largest in history window), 'most_recent'."
},
)
wgrad: bool = field(
default=True,
metadata={
"help": "When False, override FP8 config and compute weight gradients in higher precision."
},
)
dot_product_attention: bool = field(
default=False,
metadata={"help": "Use FP8 implementation of Dot Product Attention."},
)
multi_head_attention: bool = field(
default=False,
metadata={"help": "Use FP8 implementation of Multi Head Attention."},
)
tp_only_amax_red: bool = field(
default=False,
metadata={"help": "Reduce FP8 AMAX only in TP or TP-CP domain."},
)
first_last_layers_bf16: bool = field(
default=False,
metadata={
"help": "Retain first and last N TransformerBlocks in BF16 instead of FP8."
},
)
num_layers_at_start_in_bf16: int = field(
default=1,
metadata={
"help": "Number of layers at start to keep in BF16 when first_last_layers_bf16 is True."
},
)
num_layers_at_end_in_bf16: int = field(
default=1,
metadata={
"help": "Number of layers at end to keep in BF16 when first_last_layers_bf16 is True."
},
)
direct_convert: bool = field(
default=True,
metadata={
"help": "Whether to use direct FP8 conversion during weight updates and save/load. "
"When True, FP8 parameters are directly converted between TE FP8 and PyTorch FP8 "
"without intermediate dequantization/quantization."
},
)
@dataclass
class MegatronEngineConfig:
    """Settings for the Megatron-LM training backend.

    Most fields are named after the Megatron-LM options they feed; see the
    Megatron-LM documentation for their detailed behaviour.
    """

    # ---- Distributed training ----
    wrap_with_ddp: bool = field(default=True)
    use_torch_fsdp2: bool = field(default=False)  # TODO: pending test
    use_custom_fsdp: bool = field(default=False)  # TODO: pending test
    ddp: DistributedDataParallelConfig = field(
        default_factory=DistributedDataParallelConfig
    )
    virtual_pipeline_parallel_size: int = field(
        default=1,
        metadata=dict(
            help=(
                "Virtual pipeline parallel size for Megatron interleaved schedule. "
                "Set to >1 to enable VPP. Default is 1 (disabled)."
            )
        ),
    )
    # Don't use MegatronOptimizerConfig here because OmegaConf
    # does not recognize the annotation "torch.dtype"
    overlap_param_gather_with_optimizer_step: bool = field(default=False)

    # ---- Precision (dtype names kept as strings for OmegaConf) ----
    use_precision_aware_optimizer: bool = field(default=False)
    main_grads_dtype: str = field(default="float32")
    main_params_dtype: str = field(default="float32")
    exp_avg_dtype: str = field(default="float32")
    exp_avg_sq_dtype: str = field(default="float32")

    # ---- Checkpointing ----
    async_save: bool = field(default=False)
    use_checkpoint_opt_param_scheduler: bool = field(default=True)

    # ---- Determinism ----
    # NOTE: This option forces torch to use deterministic algorithms,
    # which makes sure that two forward passes with the same input
    # will produce the same output. However, it may have a performance impact.
    # It is recommended to set this option to True for RL training on MoE models for stability.
    use_deterministic_algorithms: bool = field(default=False)

    # ---- Recomputation (only effective when gradient_checkpointing=True) ----
    recompute_granularity: str | None = field(default="full")
    recompute_method: str | None = field(default="uniform")
    recompute_num_layers: int | None = field(default=1)
    distribute_saved_activations: bool | None = field(default=None)
    recompute_modules: list[str] | None = field(default=None)

    # ---- Mixture of Experts ----
    moe_router_dtype: str | None = field(default="fp32")
    moe_shared_expert_overlap: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Enable overlapping between shared expert computations and dispatcher communications. "
                "Without this, the shared experts execute after the routed experts."
            )
        ),
    )
    moe_enable_deepep: bool = field(default=False)
    moe_token_dispatcher_type: str = field(
        default="alltoall",
        metadata=dict(
            help="Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'."
        ),
    )
    moe_permute_fusion: bool = field(
        default=False,
        metadata=dict(help="Fuse token rearrangement ops during token dispatching."),
    )

    # ---- FP8 training (None disables it) ----
    fp8_config: FP8EngineConfig | None = field(default=None)
class SchedulingStrategyType(str, Enum):
    """Named values for a scheduling strategy type.

    The members mirror the ``choices`` declared on ``SchedulingStrategy.type``.
    Because the class mixes in ``str``, each member compares equal to its plain
    string value.
    """

    # The default placement (SchedulingStrategy.type defaults to this string).
    separation = "separation"
    # Placement shared with SchedulingStrategy.target.
    colocation = "colocation"
@dataclass
class SchedulingStrategy:
type: str = field(
default="separation",
metadata={"choices": ["separation", "colocation"]},
)
target: str | None = field(
default=None, metadata={"help": "The target role to be colocated with"}
)
fork: bool = field(
default=True,
metadata={
"help": "When True with colocation, the target worker spawns a new "
"process on the same node/GPUs instead of sharing its process. "
"Provides process isolation while sharing GPU resources."
},
)
@dataclass
class SchedulingSpec:
cpu: int = field(
default=8, metadata={"help": "Number of CPU cores required per GPU"}
)
gpu: int = field(
default=0,
metadata={
"help": "Number of GPU units required. Used only when allocating pods."
},
)
mem: int = field(
default=32, metadata={"help": "Amount of memory (GB) required per GPU"}
)
port_count: int = field(default=2, metadata={"help": "Number of ports to expose"})
image: str = field(
default="/storage/openpsi/images/areal-latest.sif",
metadata={
"help": "Docker/Singularity container image to use. "
"Currently only used by Slurm. Will be potentially used by Kubernetes in the future."
},
)
task_type: str = field(
default="worker",
metadata={
"help": "Task type (e.g., worker, engine)",
"choices": ["worker", "engine"],
},
)
env_vars: dict[str, str] = field(
default_factory=dict,
metadata={"help": "Environment variables for the container"},
)
cmd: str | None = field(
default=None,
metadata={
"help": "Command to execute inside the container. Defaults to AReaL's RPC server."
},
)
# Slurm specific options
srun_additional_args: str = field(
default="--unbuffered --mpi=pmi2 -K --chdir $PWD",
metadata={
"help": "Additional arguments to pass to the srun command. Only used by slurm."
},
)
additional_bash_cmds: list[str] | None = field(
default=None,
metadata={
"help": "Additional bash commands to setup the container before running "
"the torchrun command. Only used by slurm."
},
)
container_type: str = field(
default="apptainer",
metadata={
"help": "Type of containers used in slurm",
"choices": ["apptainer", "none"],
},
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
nodelist: str | None = field(
default=None, metadata={"help": "sbatch/srun's `--nodelist` option for slurm."}
)
exclude: str | None = field(
default=None, metadata={"help": "sbatch/srun's `--exclude` option for slurm."}
)
@dataclass
class TrainEngineConfig:
    """Settings shared by training engines: model, micro-batching, backend, LoRA and scheduling."""

    # ---- Identity and model source ----
    experiment_name: str = MISSING
    trial_name: str = MISSING
    path: str = field(default="", metadata=dict(help="Path to HuggingFace checkpoint"))
    attn_impl: str = field(
        default="flash_attention_2",
        metadata=dict(
            help="Attention implementation for huggingface transformers model.",
            choices=["flash_attention_2"],
        ),
    )
    init_from_scratch: bool = field(
        default=False, metadata=dict(help="Initialize model weights randomly")
    )
    is_critic: bool = field(
        default=False,
        metadata=dict(help="Whether to use a critic/reward model"),
    )
    temperature: float = field(
        default=1.0, metadata=dict(help="Temperature during generation.")
    )

    # ---- Runtime micro-batch limits ----
    mb_spec: MicroBatchSpec = field(default_factory=MicroBatchSpec)
    pad_to_maximum: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Whether to pad each microbatch to the length upper bound specified by mb_spec. "
                "Can reduce memory fragmentation but slows down training."
            )
        ),
    )

    # ---- Training backend ----
    disable_dropout: bool = field(
        default=False, metadata=dict(help="Disable dropout layers during training")
    )
    gradient_checkpointing: bool = field(
        default=False, metadata=dict(help="Enable gradient checkpointing")
    )
    dtype: str = field(default="bfloat16", metadata=dict(help="Parameter data type."))
    grad_reduce_dtype: str = field(
        default="float32", metadata=dict(help="Gradient reduction data type.")
    )
    optimizer: OptimizerConfig | None = field(
        default=None,
        metadata=dict(help="Optimizer configuration. None means no training."),
    )
    weight_update_mode: str = field(
        default="xccl",
        metadata=dict(help="Weight update backend type.", choices=["disk", "xccl"]),
    )
    fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
    archon: ArchonEngineConfig = field(default_factory=ArchonEngineConfig)
    megatron: MegatronEngineConfig = field(default_factory=MegatronEngineConfig)

    # ---- LoRA / PEFT ----
    use_lora: bool = field(
        default=False,
        metadata=dict(
            help="Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang."
        ),
    )
    lora_rank: int = field(default=32, metadata=dict(help="lora rank"))
    lora_alpha: int = field(default=16, metadata=dict(help="lora alpha"))
    target_modules: list[str] = field(
        default_factory=list,
        metadata=dict(help="lora target_modules."),
    )
    peft_type: str = field(
        default="lora",
        metadata=dict(help="peft method type. Only LoRA is supported for now."),
    )

    # ---- Tree training ----
    enable_tree_training: bool = field(
        default=False,
        metadata=dict(help="Enable tree training with flex attention module."),
    )

    # ---- Scheduling (consumed by the TrainController per the help text) ----
    scheduling_spec: tuple[SchedulingSpec, ...] = field(
        default_factory=lambda: (
            SchedulingSpec(cmd="python -m areal.scheduler.rpc.rpc_server"),
        ),
        metadata=dict(
            help=(
                "Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: "
                "if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; "
                "if 2 specs provided, first one is for worker, second one is for engine. "
                "Currently only used by the TrainController."
            )
        ),
    )
    scheduling_strategy: SchedulingStrategy = field(
        default_factory=SchedulingStrategy,
        metadata=dict(
            help=(
                "The scheduling strategy of this TrainEngine, either separation or colocation. "
                "Currently only used by the TrainController."
            )
        ),
    )

    def __post_init__(self):
        """Check that the scheduling specs and FSDP loading options are consistent.

        Raises:
            ValueError: If ``scheduling_spec`` does not hold exactly one or two
                specs, or if ``fsdp.memory_efficient_load`` is combined with
                ``init_from_scratch``.
        """
        spec_count = len(self.scheduling_spec)
        if spec_count not in (1, 2):
            raise ValueError(
                f"scheduling_spec must contain 1 or 2 SchedulingSpec, got {spec_count}"
            )
        # Loading pretrained weights efficiently is meaningless when no weights are loaded.
        if not (self.fsdp.memory_efficient_load and self.init_from_scratch):
            return
        raise ValueError(
            "memory_efficient_load cannot be used with init_from_scratch=True. "
            "memory_efficient_load is for loading pretrained weights on CPU, "
            "but init_from_scratch creates a model without loading any weights."
        )
@dataclass
class PPOActorConfig(TrainEngineConfig):
"""Configuration for PPO actor model, a subclass of a TrainEngine."""
# Core PPO/GRPO Parameters
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
eps_clip_higher: float | None = field(
default=None,
metadata={
"help": "Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value."
},
)
c_clip: float | None = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping."
},
)
# M2PO
m2_threshold: float | None = field(
default=None, metadata={"help": "The second momentum threshold for M2PO."}
)
# Reward
reward_norm: NormConfig | None = field(
default=None,
metadata={"help": "Normalization configuration for rewards"},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
overlong_reward_penalty: bool = field(
default=False,
metadata={"help": "Penalty for overlong sequences. Used within DAPO."},
)
overlong_tokens: int | None = field(
default=None,
metadata={"help": "Number of tokens in the tail that will receive a penalty"},
)
overlong_penalty_factor: float | None = field(
default=None,
metadata={"help": "Penalty factor for tokens in the tail"},
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)
# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: NormConfig | None = field(
default=None, metadata={"help": "Normalization configuration for advantages."}
)
# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})
kl_estimator: str = field(
default="k1",
metadata={"help": "KL divergence estimator", "choices": ["k1", "k2", "k3"]},
)
# SAPO (Soft Adaptive Policy Optimization) - https://arxiv.org/abs/2511.20347
use_sapo_loss: bool = field(
default=False,
metadata={"help": "Use SAPO loss (mutually exclusive with PPO clipping)"},
)
sapo_tau_pos: float = field(
default=1.0,
metadata={"help": "SAPO temperature for positive advantages"},