Skip to content

Commit 6d86e74

Browse files
Merge pull request #3639 from AI-Hypercomputer:chengnuojin-no-exp5
PiperOrigin-RevId: 899212298
2 parents 50ba0a1 + eed004f commit 6d86e74

61 files changed

Lines changed: 4310 additions & 64669 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

tests/unit/custom_mesh_and_rule_test.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

tests/unit/sharding_compare_test.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,10 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
114114
# Requires JAX TPU support to generate the simulated TPU topology.
115115
@pytest.mark.cpu_only
116116
@pytest.mark.tpu_backend
117-
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
118-
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
117+
@pytest.mark.parametrize("model_name, topology, num_slice, custom_mesh_and_rule, overrides", TEST_CASES)
118+
def test_sharding_dump_for_model(
119+
model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple
120+
) -> None:
119121
"""
120122
Test sharding configurations from train_compile.get_shaped_inputs.
121123
This test verifies that the sharding configurations for various models and topologies remain consistent with golden files.
@@ -132,9 +134,16 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
132134
"enable_nnx=False",
133135
"pure_nnx_decoder=False",
134136
]
137+
if custom_mesh_and_rule:
138+
params.append(f"custom_mesh_and_rule={custom_mesh_and_rule}")
139+
if overrides:
140+
params.extend(overrides)
135141

136142
root_dir = "tests/utils/sharding_info"
137-
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
143+
rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
144+
if overrides:
145+
rule_name += "_" + "_".join(overrides)
146+
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}", rule_name)
138147

139148
named_json_path = os.path.join(base_path, "named_shardings.json")
140149
logical_json_path = os.path.join(base_path, "logical_shardings.json")
@@ -206,12 +215,16 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
206215

207216
@pytest.fixture(
208217
scope="module",
209-
params=[pytest.param(case, id=f"{case[0]}-{case[1]}-{case[2]}") for case in TEST_CASES],
218+
params=[pytest.param(case, id=f"{case[0]}-{case[1]}-{case[2]}-{case[3]}-{''.join(case[4])}") for case in TEST_CASES],
210219
)
211220
def abstract_state_and_shardings(request):
212221
"""Pytest fixture to set up model, config, and generate abstract state once per test case."""
213-
model_name, topology, num_slice = request.param
214-
print(f"Testing model: {model_name}, topology: {topology}, num_slices: {num_slice}", flush=True)
222+
model_name, topology, num_slice, custom_mesh_and_rule, overrides = request.param
223+
print(
224+
f"Testing model: {model_name}, topology: {topology}, num_slices: {num_slice}, "
225+
"rule: {custom_mesh_and_rule}, overrides: {overrides}",
226+
flush=True,
227+
)
215228
params = [
216229
"/deps/MaxText/tests/unit/sharding_compare_test",
217230
get_test_config_path(),
@@ -223,6 +236,10 @@ def abstract_state_and_shardings(request):
223236
"enable_nnx=False",
224237
"pure_nnx_decoder=False",
225238
]
239+
if custom_mesh_and_rule:
240+
params.append(f"custom_mesh_and_rule={custom_mesh_and_rule}")
241+
if overrides:
242+
params.extend(overrides)
226243
config = pyconfig.initialize(params)
227244
validate_config(config)
228245

@@ -245,7 +262,16 @@ def abstract_state_and_shardings(request):
245262
# Get logical shardings from maxtext_utils
246263
logical_shardings = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)
247264

248-
return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
265+
return (
266+
model_name,
267+
topology,
268+
num_slice,
269+
custom_mesh_and_rule,
270+
overrides,
271+
abstract_state,
272+
state_mesh_shardings,
273+
logical_shardings,
274+
)
249275

250276

251277
@pytest.mark.cpu_only
@@ -257,9 +283,16 @@ class TestGetAbstractState:
257283
def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pylint: disable=redefined-outer-name
258284
"""Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
259285

260-
model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings = (
261-
abstract_state_and_shardings
262-
)
286+
(
287+
model_name,
288+
topology,
289+
num_slice,
290+
custom_mesh_and_rule,
291+
overrides,
292+
abstract_state,
293+
state_mesh_shardings,
294+
logical_shardings,
295+
) = abstract_state_and_shardings
263296

264297
assert hasattr(abstract_state, "params")
265298
assert hasattr(abstract_state, "opt_state")
@@ -268,7 +301,10 @@ def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pyl
268301
assert param_leaf.dtype == jnp.float32
269302

270303
root_dir = "tests/utils/sharding_info" # Or your target directory
271-
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
304+
rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
305+
if overrides:
306+
rule_name += "_" + "_".join(overrides)
307+
base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}", rule_name)
272308
os.makedirs(base_path, exist_ok=True) # Ensure directory exists for saving actual
273309

274310
error_messages = []

tests/utils/run_sharding_dump.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,57 +58,68 @@
5858
flags.DEFINE_string("model_name", None, "Specific model name to dump.")
5959
flags.DEFINE_string("topology", None, "Specific topology to dump.")
6060
flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.")
61-
62-
63-
def run_single_dump(model_name: str, topology: str, num_slice: str) -> None:
64-
"""Generate sharding json file for one specific model, topology and slice."""
65-
subprocess.run(
66-
[
67-
"python3",
68-
"-m",
69-
"tests.utils.sharding_dump",
70-
get_test_config_path(),
71-
f"compile_topology={topology}",
72-
f"compile_topology_num_slices={num_slice}",
73-
f"model_name={model_name}",
74-
"weight_dtype=float32",
75-
"log_config=false",
76-
"debug_sharding=true",
77-
],
78-
check=True,
79-
)
61+
flags.DEFINE_string("custom_mesh_and_rule", None, "Specific custom_mesh_and_rule to dump.")
62+
63+
64+
def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple) -> None:
65+
"""Generate sharding json file for one specific model, topology, slice and rule."""
66+
args = [
67+
"python3",
68+
"-m",
69+
"tests.utils.sharding_dump",
70+
get_test_config_path(),
71+
f"compile_topology={topology}",
72+
f"compile_topology_num_slices={num_slice}",
73+
f"model_name={model_name}",
74+
"weight_dtype=float32",
75+
"log_config=false",
76+
"debug_sharding=true",
77+
]
78+
if custom_mesh_and_rule:
79+
args.append(f"custom_mesh_and_rule={custom_mesh_and_rule}")
80+
if overrides:
81+
args.extend(overrides)
82+
subprocess.run(args, check=True)
8083

8184

8285
def main(argv: Sequence[str]) -> None:
83-
"""Generate json files for every combination of model, topology and slices."""
86+
"""Generate json files for every combination of model, topology, slices and rule."""
8487
if FLAGS.model_name and FLAGS.topology and FLAGS.num_slice:
85-
cases_to_run = [(FLAGS.model_name, FLAGS.topology, FLAGS.num_slice)]
88+
cmr = FLAGS.custom_mesh_and_rule if FLAGS.custom_mesh_and_rule is not None else ""
89+
# We do not natively support overrides via CLI FLAGS. To test explicit cases,
90+
# rely on the predefined TEST_CASES.
91+
cases_to_run = [(FLAGS.model_name, FLAGS.topology, FLAGS.num_slice, cmr, ())]
8692
print(
8793
"Running specific case from command line: "
88-
f"Model={FLAGS.model_name}, Topology={FLAGS.topology}, NumSlice={FLAGS.num_slice}"
94+
f"Model={FLAGS.model_name}, Topology={FLAGS.topology}, NumSlice={FLAGS.num_slice}, Rule={cmr}"
8995
)
90-
elif FLAGS.model_name or FLAGS.topology or FLAGS.num_slice:
96+
elif FLAGS.model_name or FLAGS.topology or FLAGS.num_slice or FLAGS.custom_mesh_and_rule:
9197
print("Error: To specify a single test case, --model_name, --topology, and --num_slice must all be provided.")
9298
return
9399
else:
94100
cases_to_run = TEST_CASES
95101
print(f"Running all {len(TEST_CASES)} predefined test cases.")
96102

97103
total = len(cases_to_run)
98-
for i, (model_name, topology, num_slice) in enumerate(cases_to_run):
99-
print(f"\n[{i+1}/{total}] Processing: {model_name} | {topology} | Slice {num_slice}")
100-
101-
base_path = Path(f"{MAXTEXT_REPO_ROOT}/tests/utils/sharding_info/{model_name}/" f"{topology}/slice_{num_slice}/")
104+
for i, (model_name, topology, num_slice, custom_mesh_and_rule, overrides) in enumerate(cases_to_run):
105+
rule_name = f"rule_{custom_mesh_and_rule}" if custom_mesh_and_rule else "rule_default"
106+
if overrides:
107+
rule_name += "_" + "_".join(overrides)
108+
print(f"\n[{i+1}/{total}] Processing: {model_name} | {topology} | Slice {num_slice} | Rule {rule_name}")
109+
110+
base_path = Path(
111+
f"{MAXTEXT_REPO_ROOT}/tests/utils/sharding_info/{model_name}/" f"{topology}/slice_{num_slice}/{rule_name}/"
112+
)
102113
json_path_named = base_path / "named_shardings.json"
103114
json_path_logical = base_path / "logical_shardings.json"
104115

105116
if json_path_named.exists() and json_path_logical.exists():
106117
print(" -> Sharding files already exist. Regenerating to overwrite.")
107118

108119
try:
109-
run_single_dump(model_name, topology, str(num_slice))
120+
run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides)
110121
except subprocess.CalledProcessError:
111-
print(f"!!! FAILED: {model_name} {topology} {num_slice}")
122+
print(f"!!! FAILED: {model_name} {topology} {num_slice} {custom_mesh_and_rule} overrides={overrides}")
112123

113124

114125
if __name__ == "__main__":

0 commit comments

Comments
 (0)