Skip to content

Commit 8c8a8db

Browse files
authored
feat(utils): add Karmarkar-Karp partitioning algorithm for sequence packing (#1151)
Add Karmarkar-Karp (KK, also known as the Largest Differencing Method) as an alternative to First Fit Decreasing (FFD) for micro-batch allocation. KK produces more balanced partitions with a lower max-min spread, which benefits RL workloads with variable sequence lengths.

Key changes:
- Add _KKSet, _KKState, _kk_partition, and kk_allocate in seqpack.py
- Add packing_algorithm field to MicroBatchSpec (ffd/kk)
- Wire KK allocation through dist_rollout and data utils
- Add sequence_packing docs (en/zh) and CLI reference updates
- Add comprehensive unit tests and torchrun benchmark

Refs: #1151
1 parent bc9f009 commit 8c8a8db

12 files changed

Lines changed: 1408 additions & 20 deletions

File tree

areal/api/cli_args.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PROX_LOGP_METHODS_ALL,
2828
)
2929
from areal.utils.pkg_version import is_version_less
30+
from areal.utils.seqpack import PACKING_ALGORITHMS
3031

3132
if TYPE_CHECKING:
3233
from transformers import PreTrainedTokenizerFast
@@ -123,6 +124,27 @@ class MicroBatchSpec:
123124
"help": "Divisor for the number of micro-batches. The final number of micro-batches will be adjusted to be divisible by this value.",
124125
},
125126
)
127+
packing_algorithm: str = field(
128+
default="ffd",
129+
metadata={
130+
"help": (
131+
"Sequence packing algorithm for micro-batch allocation. "
132+
"Supported values: 'ffd' (First Fit Decreasing, default), "
133+
"'kk' (Karmarkar-Karp, better balance but slightly slower). "
134+
"KK is recommended when workload balance across DP ranks is "
135+
"critical (e.g., large-scale RL training with variable-length sequences)."
136+
),
137+
"choices": ["ffd", "kk"],
138+
},
139+
)
140+
141+
def __post_init__(self):
142+
"""Validate the configured packing algorithm.

Checks that ``self.packing_algorithm`` is a member of
``PACKING_ALGORITHMS`` (imported from ``areal.utils.seqpack``).

NOTE(review): presumably invoked automatically after field
initialization, assuming the enclosing class is a ``@dataclass`` --
its decorator is outside this view; confirm.

Raises:
ValueError: If ``packing_algorithm`` is not in ``PACKING_ALGORITHMS``;
the message lists the accepted values in sorted order.
"""
143+
if self.packing_algorithm not in PACKING_ALGORITHMS:
144+
raise ValueError(
145+
f"packing_algorithm must be one of {sorted(PACKING_ALGORITHMS)}, "
146+
f"got '{self.packing_algorithm}'"
147+
)
126148

127149
@classmethod
128150
def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
@@ -132,6 +154,7 @@ def new(cls, mb_spec: "MicroBatchSpec", **kwargs):
132154
granularity=mb_spec.granularity,
133155
max_tokens_per_mb=mb_spec.max_tokens_per_mb,
134156
n_mbs_divisor=mb_spec.n_mbs_divisor,
157+
packing_algorithm=mb_spec.packing_algorithm,
135158
)
136159
fields.update(kwargs)
137160
return cls(**fields)

areal/infra/dist_rollout.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
split_and_unpad_tensor,
1616
tensor_container_to,
1717
)
18-
from areal.utils.seqpack import ffd_allocate
18+
from areal.utils.seqpack import get_allocate_fn
1919

2020

2121
@dataclass
@@ -29,6 +29,7 @@ class RedistributedData:
2929
def redistribute_trajectories(
3030
trajectories: list[dict[str, Any]],
3131
group=None,
32+
packing_algorithm: str = "ffd",
3233
) -> RedistributedData:
3334
"""Redistribute a list of trajectory dicts across a process group.
3435
@@ -43,6 +44,8 @@ def redistribute_trajectories(
4344
contains tensors with shape [batch_size, seqlen, ...].
4445
group : dist.ProcessGroup, optional
4546
The process group for communication. If None, uses the default group.
47+
packing_algorithm : str, optional
48+
Packing algorithm to use ("ffd" or "kk"). Default is "ffd".
4649
4750
Returns
4851
-------
@@ -73,9 +76,10 @@ def redistribute_trajectories(
7376
for d in all_data
7477
]
7578

76-
# Allocate trajectories to ranks using first-fit-decreasing
79+
allocate_fn = get_allocate_fn(packing_algorithm)
80+
# Allocate trajectories to ranks using the configured packing algorithm
7781
# No capacity limit leads to balanced partition across this group
78-
group_indices = ffd_allocate(
82+
group_indices = allocate_fn(
7983
seqlens, capacity=int(1e12), min_groups=dist.get_world_size(group)
8084
)
8185
local_indices = group_indices[dist.get_rank(group=group)]
@@ -119,9 +123,13 @@ def _broadcast_and_redistribute_trajectories(
119123
Redistributed and broadcast batch available on all ranks (list of trajs)
120124
"""
121125
if trajectories is not None:
126+
config = getattr(self.train_engine, "config", None)
127+
mb_spec = getattr(config, "mb_spec", None)
128+
packing_algorithm = getattr(mb_spec, "packing_algorithm", "ffd")
122129
redist = redistribute_trajectories(
123130
trajectories,
124131
group=self.train_engine.data_parallel_group,
132+
packing_algorithm=packing_algorithm,
125133
)
126134
batch = redist.data
127135
else:

areal/utils/data.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from areal.infra.platforms import current_platform
2020
from areal.utils import logging, seqpack
2121
from areal.utils.math import align
22+
from areal.utils.seqpack import get_allocate_fn
2223

2324
logger = logging.getLogger("DataUtils")
2425

@@ -445,8 +446,23 @@ def unpack_sequence(
445446

446447

447448
def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: list[int]) -> list[list[int]]:
449+
"""Allocate sequences into balanced micro-batches using the configured algorithm.
450+
451+
The packing algorithm is determined by ``mb_spec.packing_algorithm``:
452+
- ``"ffd"`` (default): First Fit Decreasing — fast greedy heuristic.
453+
- ``"kk"``: Karmarkar-Karp — produces more balanced partitions at a
454+
slight computational cost.
455+
456+
Args:
457+
mb_spec: MicroBatchSpec containing packing configuration.
458+
lens: List of sequence lengths to allocate.
459+
460+
Returns:
461+
List of lists of indices, one per micro-batch.
462+
"""
448463
assert mb_spec.max_tokens_per_mb is not None
449-
group_indices = seqpack.ffd_allocate(
464+
allocate_fn = get_allocate_fn(getattr(mb_spec, "packing_algorithm", "ffd"))
465+
group_indices = allocate_fn(
450466
lens,
451467
mb_spec.max_tokens_per_mb,
452468
min_groups=mb_spec.n_mbs,

0 commit comments

Comments
 (0)