Skip to content

Commit 1dfa07c

Browse files
author
The tunix Authors
committed
Merge pull request #1430 from google:nicogrande/fix-delete-buffers
PiperOrigin-RevId: 906544403
2 parents 4cf1448 + dda8673 commit 1dfa07c

3 files changed

Lines changed: 235 additions & 36 deletions

File tree

tests/generate/utils_test.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,135 @@ def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
14661466
jnp.concatenate([wi_0_val[1], wi_1_val[1]], axis=-1),
14671467
)
14681468

1469+
def test_transfer_state_directly_delete_dst_buffers_no_chunking(self):
1470+
"""delete_dst_buffers=True must never pass deleted arrays to reshard_fn."""
1471+
src_val = jnp.array([1.0, 2.0, 3.0])
1472+
src_state = nnx.Dict(
1473+
decoder=nnx.Dict(layer0=nnx.Dict(weight=nnx.Param(src_val)))
1474+
)
1475+
dst_state = nnx.Dict(
1476+
decoder=nnx.Dict(
1477+
layer0=nnx.Dict(weight=nnx.Param(jnp.zeros(3, dtype=jnp.float32)))
1478+
)
1479+
)
1480+
1481+
inspected_targets = []
1482+
1483+
def reshard_fn(source, target):
1484+
inspected_targets.extend(jax.tree_util.tree_leaves(target))
1485+
return source
1486+
1487+
utils.transfer_state_directly(
1488+
src_state, dst_state, reshard_fn=reshard_fn, delete_dst_buffers=True
1489+
)
1490+
1491+
self.assertNotEmpty(inspected_targets)
1492+
for leaf in inspected_targets:
1493+
# Pre-fix this would have been a (possibly deleted) jax.Array.
1494+
self.assertIsInstance(
1495+
leaf, (NamedSharding, sharding.SingleDeviceSharding)
1496+
)
1497+
np.testing.assert_array_equal(
1498+
dst_state['decoder']['layer0']['weight'][...], src_val
1499+
)
1500+
1501+
def test_transfer_state_directly_delete_dst_buffers_chunked(self):
1502+
"""delete_dst_buffers=True works through the chunked path too."""
1503+
src_state = nnx.Dict(
1504+
decoder=nnx.Dict(**{
1505+
f'layer{i}': nnx.Dict(weight=nnx.Param(jnp.array([float(i + 1)])))
1506+
for i in range(4)
1507+
})
1508+
)
1509+
dst_state = nnx.Dict(
1510+
decoder=nnx.Dict(**{
1511+
f'layer{i}': nnx.Dict(weight=nnx.Param(jnp.array([0.0])))
1512+
for i in range(4)
1513+
})
1514+
)
1515+
1516+
inspected_targets = []
1517+
1518+
def reshard_fn(source, target):
1519+
inspected_targets.extend(jax.tree_util.tree_leaves(target))
1520+
return source
1521+
1522+
utils.transfer_state_directly(
1523+
src_state,
1524+
dst_state,
1525+
reshard_fn=reshard_fn,
1526+
delete_dst_buffers=True,
1527+
reshard_chunk_size=2,
1528+
)
1529+
1530+
self.assertNotEmpty(inspected_targets)
1531+
for leaf in inspected_targets:
1532+
self.assertIsInstance(
1533+
leaf, (NamedSharding, sharding.SingleDeviceSharding)
1534+
)
1535+
for i in range(4):
1536+
np.testing.assert_array_equal(
1537+
dst_state['decoder'][f'layer{i}']['weight'][...],
1538+
jnp.array([float(i + 1)]),
1539+
)
1540+
1541+
def test_transfer_state_directly_delete_dst_buffers_skips_aliased_buffers(
1542+
self,
1543+
):
1544+
"""When src and dst Variables share a backing jax.Array, skip deletion."""
1545+
shared = jnp.array([1.0, 2.0, 3.0])
1546+
src_state = nnx.Dict(
1547+
decoder=nnx.Dict(layer0=nnx.Dict(weight=nnx.Param(shared)))
1548+
)
1549+
# Same backing jax.Array object on both sides — typical of collocated
1550+
# trainer/sampler setups where the rollout state aliases trainer weights.
1551+
dst_state = nnx.Dict(
1552+
decoder=nnx.Dict(layer0=nnx.Dict(weight=nnx.Param(shared)))
1553+
)
1554+
1555+
mock_reshard = lambda source, target: source
1556+
utils.transfer_state_directly(
1557+
src_state,
1558+
dst_state,
1559+
reshard_fn=mock_reshard,
1560+
delete_dst_buffers=True,
1561+
)
1562+
1563+
# If deletion misfired, the next access raises "Array has been deleted".
1564+
np.testing.assert_array_equal(np.asarray(shared), [1.0, 2.0, 3.0])
1565+
np.testing.assert_array_equal(
1566+
dst_state['decoder']['layer0']['weight'][...], [1.0, 2.0, 3.0]
1567+
)
1568+
1569+
def test_transfer_state_directly_delete_dst_buffers_scanned_layers(self):
1570+
"""Unstacked-slice targets remain valid after delete_dst_buffers=True."""
1571+
scanned = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
1572+
src_state = nnx.Dict(layers=nnx.Dict(weight=nnx.Param(scanned)))
1573+
dst_state = nnx.Dict(**{
1574+
'layers_0': nnx.Dict(
1575+
weight=nnx.Param(jnp.zeros(4, dtype=jnp.float32))
1576+
),
1577+
'layers_1': nnx.Dict(
1578+
weight=nnx.Param(jnp.zeros(4, dtype=jnp.float32))
1579+
),
1580+
})
1581+
1582+
mock_reshard = lambda source, target: source
1583+
utils.transfer_state_directly(
1584+
src_state,
1585+
dst_state,
1586+
reshard_fn=mock_reshard,
1587+
scan_axis=0,
1588+
delete_dst_buffers=True,
1589+
)
1590+
1591+
np.testing.assert_array_equal(
1592+
dst_state['layers_0']['weight'][...], scanned[0]
1593+
)
1594+
np.testing.assert_array_equal(
1595+
dst_state['layers_1']['weight'][...], scanned[1]
1596+
)
1597+
14691598

14701599
class ResolveParallelismSizesTest(parameterized.TestCase):
14711600

tunix/generate/utils.py

Lines changed: 103 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,63 @@ def _fuse_moe_weights(src_flat: Dict[Tuple[str, ...], Any], tgt_flat: Dict[Tuple
10501050
return new_src_flat
10511051

10521052

1053+
def _collect_src_buffer_ids(src_flat: Mapping[Any, Any]) -> set:
1054+
"""Collects physical device buffer pointers for arrays in src_flat.
1055+
1056+
Used to detect when a target jax.Array shares its underlying buffer with a
1057+
source array — Python identity (`is`) is insufficient because two distinct
1058+
jax.Array wrappers can be backed by the same physical shard (e.g. when source slices
1059+
come from a scanned tensor that also backs another spec entry).
1060+
"""
1061+
ids = set()
1062+
for v in src_flat.values():
1063+
arr = v.value if hasattr(v, 'value') else v
1064+
if not hasattr(arr, 'addressable_shards'):
1065+
continue
1066+
for shard in arr.addressable_shards:
1067+
try:
1068+
ids.add(shard.data.unsafe_buffer_pointer())
1069+
except Exception: # pylint: disable=broad-except
1070+
pass
1071+
return ids
1072+
1073+
1074+
def _delete_target_buffers(
1075+
spec_flat: Mapping[Any, Any],
1076+
src_flat: Mapping[Any, Any],
1077+
) -> None:
1078+
"""Deletes target arrays in spec_flat that don't alias any source shard."""
1079+
src_buffer_ids = _collect_src_buffer_ids(src_flat)
1080+
for tgt_val in spec_flat.values():
1081+
tgt_arr = tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
1082+
if not hasattr(tgt_arr, 'delete') or getattr(
1083+
tgt_arr, 'is_deleted', lambda: False
1084+
)():
1085+
continue
1086+
if hasattr(tgt_arr, 'addressable_shards') and any(
1087+
shard.data.unsafe_buffer_pointer() in src_buffer_ids
1088+
for shard in tgt_arr.addressable_shards
1089+
):
1090+
continue
1091+
tgt_arr.delete()
1092+
1093+
1094+
def _snapshot_dst_sharding(arr: Any) -> Any:
1095+
"""Snapshots a destination sharding leaf for reshard_fn's target tree.
1096+
1097+
Captured *before* any potential `.delete()` on `arr` so the caller never
1098+
needs to dereference a deleted jax.Array later. `reshard_pytree`'s
1099+
`_get_dst_sharding` accepts `NamedSharding` / `SingleDeviceSharding` leaves
1100+
directly, so for those we return the existing sharding object (no rebuild).
1101+
"""
1102+
s = arr.sharding
1103+
if isinstance(
1104+
s, (jax.sharding.NamedSharding, jax.sharding.SingleDeviceSharding)
1105+
):
1106+
return s
1107+
return jax.sharding.NamedSharding(s.mesh, s.spec, memory_kind=s.memory_kind)
1108+
1109+
10531110
def _reshard_in_chunks(
10541111
src_flat: Dict[Tuple[str, ...], Any],
10551112
spec_flat: Dict[Tuple[str, ...], Any],
@@ -1085,29 +1142,32 @@ def _reshard_in_chunks(
10851142
chunk_keys = keys[start : start + chunk_size]
10861143
chunk_src_flat = {}
10871144
chunk_spec_flat = {}
1145+
chunk_dst_shardings_flat = {}
10881146
for k in chunk_keys:
10891147
src_val = src_flat.pop(k)
10901148
tgt_val = spec_flat[k]
10911149
chunk_src_flat[k] = src_val
10921150
chunk_spec_flat[k] = tgt_val
1151+
tgt_arr = tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
1152+
chunk_dst_shardings_flat[k] = _snapshot_dst_sharding(tgt_arr)
10931153

1094-
if delete_spec_buffers:
1095-
tgt_arr = tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
1096-
src_arr = src_val.value if hasattr(src_val, 'value') else src_val
1097-
if (
1098-
hasattr(tgt_arr, 'delete')
1099-
and not getattr(tgt_arr, 'is_deleted', lambda: False)()
1100-
and tgt_arr is not src_arr
1101-
):
1102-
tgt_arr.delete()
1154+
if delete_spec_buffers:
1155+
_delete_target_buffers(chunk_spec_flat, chunk_src_flat)
11031156

11041157
chunk_src = traverse_util.unflatten_dict(chunk_src_flat)
1105-
chunk_spec = traverse_util.unflatten_dict(chunk_spec_flat)
1106-
chunk_resharded = reshard_fn(source=chunk_src, target=chunk_spec)
1158+
chunk_dst_shardings = traverse_util.unflatten_dict(chunk_dst_shardings_flat)
1159+
chunk_resharded = reshard_fn(source=chunk_src, target=chunk_dst_shardings)
11071160
jax.block_until_ready(chunk_resharded)
11081161
resharded.update(traverse_util.flatten_dict(chunk_resharded))
11091162

1110-
del chunk_src, chunk_spec, chunk_resharded, chunk_src_flat, chunk_spec_flat
1163+
del (
1164+
chunk_src,
1165+
chunk_dst_shardings,
1166+
chunk_resharded,
1167+
chunk_src_flat,
1168+
chunk_spec_flat,
1169+
chunk_dst_shardings_flat,
1170+
)
11111171
return resharded
11121172

11131173

@@ -1133,13 +1193,18 @@ def transfer_state_directly(
11331193
dst_state: The destination state to transfer to.
11341194
reshard_fn: A function to shard the values.
11351195
scan_axis: The axis along which to unroll scanned layers, if needed.
1136-
delete_dst_buffers: Whether to delete buffers in the destination state after
1137-
transfer to save memory.
1196+
delete_dst_buffers: Whether to delete buffers in the destination state
1197+
before resharding to save HBM. Buffers that physically alias a source
1198+
shard are preserved automatically (see `_delete_target_buffers`).
11381199
reshard_chunk_size: When set, the final reshard is split into sequential
1139-
groups of this many flat keys instead of one monolithic call. This reduces
1140-
peak contiguous HBM pressure, which prevents XLA allocator fragmentation
1141-
errors on large models. When None (default) the original single-call
1142-
reshard behavior is preserved.
1200+
groups of this many flat keys instead of one monolithic call. This
1201+
reduces peak contiguous HBM pressure, which prevents XLA allocator
1202+
fragmentation on large models. The unit is *number of flat keys per
1203+
chunk* — per-layer key counts vary by architecture (MQA vs GQA, biases
1204+
on/off, dense vs MoE, fused vs split MoE gates), so as a rule of thumb
1205+
start with roughly `10 * num_layers` for a dense transformer and tune
1206+
downward if you still see fragmentation. When None (default) the
1207+
original single-call reshard behavior is preserved.
11431208
"""
11441209
def safe_has_key(obj: Mapping[str, Any], key: str) -> bool:
11451210
if isinstance(obj, dict):
@@ -1322,9 +1387,9 @@ def intersect_trees(
13221387
# Reshard and Update
13231388
if reshard_chunk_size is not None:
13241389
# Chunked path: split the flat weight dict into groups of reshard_chunk_size
1325-
# keys and reshard each group independently. This keeps peak contiguous HBM
1326-
# allocation proportional to chunk_size, avoiding XLA fragmentation errors
1327-
# on large models without needing to clear the compilation cache.
1390+
# entries and reshard each group independently. This keeps peak contiguous
1391+
# HBM allocation proportional to chunk_size, avoiding XLA fragmentation
1392+
# errors on large models without needing to clear the compilation cache.
13281393
src_flat = traverse_util.flatten_dict(final_source)
13291394
spec_flat = traverse_util.flatten_dict(final_spec)
13301395
del final_source, final_spec
@@ -1337,24 +1402,27 @@ def intersect_trees(
13371402
)
13381403
resharded_weights = traverse_util.unflatten_dict(resharded_flat)
13391404
else:
1405+
src_flat = traverse_util.flatten_dict(final_source)
1406+
spec_flat = traverse_util.flatten_dict(final_spec)
1407+
1408+
# Snapshot dst shardings before any deletion so reshard_fn never has to
1409+
# touch a deleted jax.Array. reshard_pytree's _get_dst_sharding accepts
1410+
# NamedSharding leaves directly, so this is a drop-in substitute for
1411+
# passing the array objects.
1412+
dst_shardings_flat = {
1413+
k: _snapshot_dst_sharding(
1414+
tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
1415+
)
1416+
for k, tgt_val in spec_flat.items()
1417+
}
1418+
13401419
if delete_dst_buffers:
1341-
src_flat = traverse_util.flatten_dict(final_source)
1342-
spec_flat = traverse_util.flatten_dict(final_spec)
1343-
for k, tgt_val in spec_flat.items():
1344-
if k in src_flat:
1345-
src_val = src_flat[k]
1346-
tgt_arr = tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
1347-
src_arr = src_val.value if hasattr(src_val, 'value') else src_val
1348-
if (
1349-
hasattr(tgt_arr, 'delete')
1350-
and not getattr(tgt_arr, 'is_deleted', lambda: False)()
1351-
and tgt_arr is not src_arr
1352-
):
1353-
tgt_arr.delete()
1420+
_delete_target_buffers(spec_flat, src_flat)
13541421

1422+
del final_spec
13551423
resharded_weights = reshard_fn(
13561424
source=final_source,
1357-
target=final_spec,
1425+
target=traverse_util.unflatten_dict(dst_shardings_flat),
13581426
)
13591427
nnx.update(dst_state, resharded_weights)
13601428

tunix/rl/rollout/base_rollout.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ class RolloutConfig:
158158
# Maximum number of concurrent sequences allowed to be processed in vLLM.
159159
rollout_vllm_max_num_seqs: Optional[int] = None
160160

161-
# Numbers of keys to reshard at a time when synchronizing weights between trainer and vLLM model.
161+
# Number of flat keys to reshard at a time when synchronizing weights between
162+
# trainer and vLLM model. None (default) reshards the whole model in one call.
163+
# Set to a smaller value to reduce peak HBM pressure on large models.
162164
rollout_vllm_reshard_chunk_size: Optional[int] = None
163165

164166
# Additional keyword arguments forwarded directly to the vLLM engine constructor.

0 commit comments

Comments
 (0)