@@ -329,7 +329,7 @@ fn compute_max_bound_inner(graph: &LayerGraph) -> Result<f64> {
329329 // default alpha for Log-layer bound propagation (ln(2^18) ≈ 12.47).
330330 let default_alpha_f64 = ( DEFAULT_SCALE_BASE as f64 ) . powi ( DEFAULT_SCALE_EXPONENT as i32 ) ;
331331
332- let shapes = propagate_shapes ( graph) ;
332+ let shapes = propagate_shapes ( graph) ? ;
333333
334334 for name in & graph. input_names {
335335 bounds. insert ( name. clone ( ) , 1.0 ) ;
@@ -413,7 +413,7 @@ fn broadcast_two(a: &[usize], b: &[usize]) -> Vec<usize> {
413413///
414414/// This is used by `compute_bounds` and `compute_max_bound_inner` so that reduction
415415/// ops (e.g. ReduceSum) can account for the actual number of elements being accumulated.
416- fn propagate_shapes ( graph : & LayerGraph ) -> HashMap < String , Vec < usize > > {
416+ fn propagate_shapes ( graph : & LayerGraph ) -> Result < HashMap < String , Vec < usize > > > {
417417 let mut shapes: HashMap < String , Vec < usize > > = HashMap :: new ( ) ;
418418
419419 // Seed with model-level input shapes.
@@ -545,43 +545,68 @@ fn propagate_shapes(graph: &LayerGraph) -> HashMap<String, Vec<usize>> {
545545 let in_shape: & [ usize ] = input_shape. as_deref ( ) . unwrap_or ( & [ ] ) ;
546546 let input_total: usize = in_shape. iter ( ) . product ( ) ;
547547
548- // First pass: resolve 0 → copy and positives; mark -1 as None.
549- let mut dims: Vec < Option < usize > > = raw
550- . iter ( )
551- . enumerate ( )
552- . map ( |( i, & d) | {
553- if d == 0 {
554- if allowzero {
555- Some ( 0 )
556- } else {
557- Some ( in_shape. get ( i) . copied ( ) . unwrap_or ( 0 ) )
558- }
559- } else if d > 0 {
560- Some ( d as usize )
548+ // First pass: resolve 0 → copy and positives; -1 → infer; other negatives are invalid.
549+ let mut dims: Vec < Option < usize > > = Vec :: with_capacity ( raw. len ( ) ) ;
550+ for ( i, & d) in raw. iter ( ) . enumerate ( ) {
551+ if d == 0 {
552+ dims. push ( if allowzero {
553+ Some ( 0 )
561554 } else {
562- None // -1 only
563- }
564- } )
565- . collect ( ) ;
555+ Some ( in_shape. get ( i) . copied ( ) . unwrap_or ( 0 ) )
556+ } ) ;
557+ } else if d > 0 {
558+ dims. push ( Some ( d as usize ) ) ;
559+ } else if d == -1 {
560+ dims. push ( None ) ; // infer sentinel
561+ } else {
562+ anyhow:: bail!(
563+ "Reshape layer '{}': invalid dimension {} at index {} (only -1 is a valid infer sentinel)" ,
564+ layer. name, d, i
565+ ) ;
566+ }
567+ }
566568
567569 // Infer the single -1 dimension when input total is known.
568- if input_total > 0 {
569- let n_unknown = dims. iter ( ) . filter ( |d| d. is_none ( ) ) . count ( ) ;
570- if n_unknown == 1 {
571- let known: usize = dims. iter ( ) . filter_map ( |& d| d) . product ( ) ;
572- if known > 0 {
573- let inferred = input_total / known;
574- for d in & mut dims {
575- if d. is_none ( ) {
576- * d = Some ( inferred) ;
577- break ;
578- }
579- }
570+ let n_unknown = dims. iter ( ) . filter ( |d| d. is_none ( ) ) . count ( ) ;
571+ if n_unknown == 1 && input_total > 0 {
572+ let known: usize = dims. iter ( ) . filter_map ( |& d| d) . product ( ) ;
573+ if known == 0 {
574+ anyhow:: bail!(
575+ "Reshape layer '{}': cannot infer -1 dimension when known product is 0" ,
576+ layer. name
577+ ) ;
578+ }
579+ if input_total % known != 0 {
580+ anyhow:: bail!(
581+ "Reshape layer '{}': input total {} is not divisible by known dims product {}" ,
582+ layer. name, input_total, known
583+ ) ;
584+ }
585+ let inferred = input_total / known;
586+ for d in & mut dims {
587+ if d. is_none ( ) {
588+ * d = Some ( inferred) ;
589+ break ;
580590 }
581591 }
592+ } else if n_unknown > 1 {
593+ anyhow:: bail!(
594+ "Reshape layer '{}': more than one -1 dimension is not allowed" ,
595+ layer. name
596+ ) ;
582597 }
583598
584- dims. into_iter ( ) . map ( |d| d. unwrap_or ( 0 ) ) . collect ( )
599+ // All None should be resolved by now; any remaining None is an error.
600+ dims. into_iter ( )
601+ . map ( |d| {
602+ d. ok_or_else ( || {
603+ anyhow:: anyhow!(
604+ "Reshape layer '{}': could not infer -1 dimension (input shape unknown)" ,
605+ layer. name
606+ )
607+ } )
608+ } )
609+ . collect :: < Result < Vec < usize > > > ( ) ?
585610 } else {
586611 input_shape. unwrap_or_default ( )
587612 }
@@ -1145,7 +1170,7 @@ fn propagate_shapes(graph: &LayerGraph) -> HashMap<String, Vec<usize>> {
11451170 }
11461171 }
11471172
1148- shapes
1173+ Ok ( shapes)
11491174}
11501175
11511176fn compute_bounds ( graph : & LayerGraph , config : & ScaleConfig ) -> Result < HashMap < String , usize > > {
@@ -1154,7 +1179,7 @@ fn compute_bounds(graph: &LayerGraph, config: &ScaleConfig) -> Result<HashMap<St
11541179 let mut n_bits_config = HashMap :: new ( ) ;
11551180
11561181 // Pre-compute tensor shapes for reduction-size estimation.
1157- let shapes = propagate_shapes ( graph) ;
1182+ let shapes = propagate_shapes ( graph) ? ;
11581183
11591184 for name in & graph. input_names {
11601185 bounds. insert ( name. clone ( ) , 1.0 ) ;
0 commit comments