Shape mismatch occurs when dim0 is not divisible by global_rank_size #672

@kane-vln

Description

Model: qwen3-235b-a22b

Config: `tp_size=4, dp_shard_size=64`

```
cosmos_rl/policy/model/qwen3_vl_moe/__init__.py", line 893, in load_hf_weights
[rank130]:     assert local_view.shape == sharded_weight.shape, (
[rank130]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank130]: AssertionError: Shape mismatch: torch.Size([594, 4096]) != torch.Size([593, 4096]) for lm_head.weight with original shape torch.Size([151936, 4096])
```
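For context, the mismatch is consistent with the vocab dimension of `lm_head.weight` not dividing evenly across the shards. A minimal sketch, assuming the weight is sharded row-wise over all `tp_size * dp_shard_size = 256` ranks (the exact sharding layout in cosmos_rl may differ):

```python
# Sketch of the divisibility problem behind the assertion.
# Assumption: lm_head.weight (dim0 = vocab_size = 151936) is split
# row-wise across tp_size * dp_shard_size = 4 * 64 = 256 shards.
vocab_size = 151936
shards = 4 * 64

# 151936 / 256 = 593.5 -- not an integer, so shards must be uneven.
base, rem = divmod(vocab_size, shards)
print(base, rem)  # 593 128

# One common scheme gives the first `rem` ranks an extra row
# (594 rows) and the rest 593 rows. If the loader instead expects a
# uniform chunk size, it computes 593 everywhere and trips the
# reported "torch.Size([594, 4096]) != torch.Size([593, 4096])".
rows = [base + (1 if r < rem else 0) for r in range(shards)]
print(rows[0], rows[-1])          # 594 593
print(sum(rows) == vocab_size)    # True
```

So any rank whose local view and expected shard are computed with these two different conventions will fail the shape assertion whenever dim0 is not divisible by the global shard count.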
