@@ -1050,6 +1050,63 @@ def _fuse_moe_weights(src_flat: Dict[Tuple[str, ...], Any], tgt_flat: Dict[Tuple
10501050 return new_src_flat
10511051
10521052
1053+ def _collect_src_buffer_ids (src_flat : Mapping [Any , Any ]) -> set :
1054+ """Collects physical device buffer pointers for arrays in src_flat.
1055+
1056+ Used to detect when a target jax.Array shares its underlying buffer with a
1057+ source array — Python identity (`is`) is insufficient because two distinct
1058+ jax.Array wrappers can back the same physical shard (e.g. when source slices
1059+ come from a scanned tensor that also backs another spec entry).
1060+ """
1061+ ids = set ()
1062+ for v in src_flat .values ():
1063+ arr = v .value if hasattr (v , 'value' ) else v
1064+ if not hasattr (arr , 'addressable_shards' ):
1065+ continue
1066+ for shard in arr .addressable_shards :
1067+ try :
1068+ ids .add (shard .data .unsafe_buffer_pointer ())
1069+ except Exception : # pylint: disable=broad-except
1070+ pass
1071+ return ids
1072+
1073+
def _delete_target_buffers(
    spec_flat: Mapping[Any, Any],
    src_flat: Mapping[Any, Any],
) -> None:
  """Deletes target arrays in spec_flat that don't alias any source shard.

  Args:
    spec_flat: Flat mapping of target leaves; leaves may be wrappers exposing
      the array via a `.value` attribute, or bare jax.Arrays.
    src_flat: Flat mapping of source leaves, used only to collect the set of
      physical buffer pointers that must be preserved.
  """
  src_buffer_ids = _collect_src_buffer_ids(src_flat)

  def _aliases_source(arr: Any) -> bool:
    # True if any shard of `arr` shares a physical buffer with a source
    # shard. A shard whose pointer cannot be read is conservatively treated
    # as aliasing, so we never delete a buffer we cannot prove is safe to
    # free — this mirrors the best-effort pointer read in
    # `_collect_src_buffer_ids` instead of letting the exception abort the
    # whole deletion pass.
    for shard in getattr(arr, 'addressable_shards', ()):
      try:
        if shard.data.unsafe_buffer_pointer() in src_buffer_ids:
          return True
      except Exception:  # pylint: disable=broad-except
        return True  # Unknown aliasing — keep the buffer.
    return False

  for tgt_val in spec_flat.values():
    tgt_arr = tgt_val.value if hasattr(tgt_val, 'value') else tgt_val
    # Skip leaves that aren't deletable arrays or are already deleted.
    if not hasattr(tgt_arr, 'delete') or getattr(
        tgt_arr, 'is_deleted', lambda: False
    )():
      continue
    if _aliases_source(tgt_arr):
      continue
    tgt_arr.delete()
1092+
1093+
1094+ def _snapshot_dst_sharding (arr : Any ) -> Any :
1095+ """Snapshots a destination sharding leaf for reshard_fn's target tree.
1096+
1097+ Captured *before* any potential `.delete()` on `arr` so the caller never
1098+ needs to dereference a deleted jax.Array later. `reshard_pytree`'s
1099+ `_get_dst_sharding` accepts `NamedSharding` / `SingleDeviceSharding` leaves
1100+ directly, so for those we return the existing sharding object (no rebuild).
1101+ """
1102+ s = arr .sharding
1103+ if isinstance (
1104+ s , (jax .sharding .NamedSharding , jax .sharding .SingleDeviceSharding )
1105+ ):
1106+ return s
1107+ return jax .sharding .NamedSharding (s .mesh , s .spec , memory_kind = s .memory_kind )
1108+
1109+
10531110def _reshard_in_chunks (
10541111 src_flat : Dict [Tuple [str , ...], Any ],
10551112 spec_flat : Dict [Tuple [str , ...], Any ],
@@ -1085,29 +1142,32 @@ def _reshard_in_chunks(
10851142 chunk_keys = keys [start : start + chunk_size ]
10861143 chunk_src_flat = {}
10871144 chunk_spec_flat = {}
1145+ chunk_dst_shardings_flat = {}
10881146 for k in chunk_keys :
10891147 src_val = src_flat .pop (k )
10901148 tgt_val = spec_flat [k ]
10911149 chunk_src_flat [k ] = src_val
10921150 chunk_spec_flat [k ] = tgt_val
1151+ tgt_arr = tgt_val .value if hasattr (tgt_val , 'value' ) else tgt_val
1152+ chunk_dst_shardings_flat [k ] = _snapshot_dst_sharding (tgt_arr )
10931153
1094- if delete_spec_buffers :
1095- tgt_arr = tgt_val .value if hasattr (tgt_val , 'value' ) else tgt_val
1096- src_arr = src_val .value if hasattr (src_val , 'value' ) else src_val
1097- if (
1098- hasattr (tgt_arr , 'delete' )
1099- and not getattr (tgt_arr , 'is_deleted' , lambda : False )()
1100- and tgt_arr is not src_arr
1101- ):
1102- tgt_arr .delete ()
1154+ if delete_spec_buffers :
1155+ _delete_target_buffers (chunk_spec_flat , chunk_src_flat )
11031156
11041157 chunk_src = traverse_util .unflatten_dict (chunk_src_flat )
1105- chunk_spec = traverse_util .unflatten_dict (chunk_spec_flat )
1106- chunk_resharded = reshard_fn (source = chunk_src , target = chunk_spec )
1158+ chunk_dst_shardings = traverse_util .unflatten_dict (chunk_dst_shardings_flat )
1159+ chunk_resharded = reshard_fn (source = chunk_src , target = chunk_dst_shardings )
11071160 jax .block_until_ready (chunk_resharded )
11081161 resharded .update (traverse_util .flatten_dict (chunk_resharded ))
11091162
1110- del chunk_src , chunk_spec , chunk_resharded , chunk_src_flat , chunk_spec_flat
1163+ del (
1164+ chunk_src ,
1165+ chunk_dst_shardings ,
1166+ chunk_resharded ,
1167+ chunk_src_flat ,
1168+ chunk_spec_flat ,
1169+ chunk_dst_shardings_flat ,
1170+ )
11111171 return resharded
11121172
11131173
@@ -1133,13 +1193,18 @@ def transfer_state_directly(
11331193 dst_state: The destination state to transfer to.
11341194 reshard_fn: A function to shard the values.
11351195 scan_axis: The axis along which to unroll scanned layers, if needed.
1136- delete_dst_buffers: Whether to delete buffers in the destination state after
1137- transfer to save memory.
1196+ delete_dst_buffers: Whether to delete buffers in the destination state
1197+ before resharding to save HBM. Buffers that physically alias a source
1198+ shard are preserved automatically (see `_delete_target_buffers`).
11381199 reshard_chunk_size: When set, the final reshard is split into sequential
1139- groups of this many flat keys instead of one monolithic call. This reduces
1140- peak contiguous HBM pressure, which prevents XLA allocator fragmentation
1141- errors on large models. When None (default) the original single-call
1142- reshard behavior is preserved.
1200+ groups of this many flat keys instead of one monolithic call. This
1201+ reduces peak contiguous HBM pressure, which prevents XLA allocator
1202+ fragmentation on large models. The unit is *number of flat keys per
1203+ chunk* — per-layer key counts vary by architecture (MQA vs GQA, biases
1204+ on/off, dense vs MoE, fused vs split MoE gates), so as a rule of thumb
1205+ start with roughly `10 * num_layers` for a dense transformer and tune
1206+ downward if you still see fragmentation. When None (default) the
1207+ original single-call reshard behavior is preserved.
11431208 """
11441209 def safe_has_key (obj : Mapping [str , Any ], key : str ) -> bool :
11451210 if isinstance (obj , dict ):
@@ -1322,9 +1387,9 @@ def intersect_trees(
13221387 # Reshard and Update
13231388 if reshard_chunk_size is not None :
13241389 # Chunked path: split the flat weight dict into groups of reshard_chunk_size
1325- # keys and reshard each group independently. This keeps peak contiguous HBM
1326- # allocation proportional to chunk_size, avoiding XLA fragmentation errors
1327- # on large models without needing to clear the compilation cache.
1390+ # entries and reshard each group independently. This keeps peak contiguous
1391+ # HBM allocation proportional to chunk_size, avoiding XLA fragmentation
1392+ # errors on large models without needing to clear the compilation cache.
13281393 src_flat = traverse_util .flatten_dict (final_source )
13291394 spec_flat = traverse_util .flatten_dict (final_spec )
13301395 del final_source , final_spec
@@ -1337,24 +1402,27 @@ def intersect_trees(
13371402 )
13381403 resharded_weights = traverse_util .unflatten_dict (resharded_flat )
13391404 else :
1405+ src_flat = traverse_util .flatten_dict (final_source )
1406+ spec_flat = traverse_util .flatten_dict (final_spec )
1407+
1408+ # Snapshot dst shardings before any deletion so reshard_fn never has to
1409+ # touch a deleted jax.Array. reshard_pytree's _get_dst_sharding accepts
1410+ # NamedSharding leaves directly, so this is a drop-in substitute for
1411+ # passing the array objects.
1412+ dst_shardings_flat = {
1413+ k : _snapshot_dst_sharding (
1414+ tgt_val .value if hasattr (tgt_val , 'value' ) else tgt_val
1415+ )
1416+ for k , tgt_val in spec_flat .items ()
1417+ }
1418+
13401419 if delete_dst_buffers :
1341- src_flat = traverse_util .flatten_dict (final_source )
1342- spec_flat = traverse_util .flatten_dict (final_spec )
1343- for k , tgt_val in spec_flat .items ():
1344- if k in src_flat :
1345- src_val = src_flat [k ]
1346- tgt_arr = tgt_val .value if hasattr (tgt_val , 'value' ) else tgt_val
1347- src_arr = src_val .value if hasattr (src_val , 'value' ) else src_val
1348- if (
1349- hasattr (tgt_arr , 'delete' )
1350- and not getattr (tgt_arr , 'is_deleted' , lambda : False )()
1351- and tgt_arr is not src_arr
1352- ):
1353- tgt_arr .delete ()
1420+ _delete_target_buffers (spec_flat , src_flat )
13541421
1422+ del final_spec
13551423 resharded_weights = reshard_fn (
13561424 source = final_source ,
1357- target = final_spec ,
1425+ target = traverse_util . unflatten_dict ( dst_shardings_flat ) ,
13581426 )
13591427 nnx .update (dst_state , resharded_weights )
13601428
0 commit comments