@@ -129,22 +129,40 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
129129 let norm_offset = api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from (
130130 1u64 << max_norm_bits,
131131 ) ) ) ;
132+ // Signed-safe division support:
133+ // norm_unscaled can be negative, while unconstrained_int_div/mod operate
134+ // over field representatives. Bias by a large power-of-two multiple of
135+ // `scale` so the dividend is non-negative in the intended integer domain,
136+ // then subtract the corresponding quotient bias.
137+ let norm_bias_bits = max_norm_bits + self . scale_exponent as usize + 2 ;
138+ let norm_bias_u256 = U256 :: from ( 1u8 ) << norm_bias_bits;
139+ let norm_bias_var = api. constant ( CircuitField :: < C > :: from_u256 ( norm_bias_u256) ) ;
140+ let norm_q_bias_u256 = U256 :: from ( 1u8 ) << ( norm_bias_bits - self . scale_exponent as usize ) ;
141+ let norm_q_bias_var = api. constant ( CircuitField :: < C > :: from_u256 ( norm_q_bias_u256) ) ;
132142 let n_alpha_sq = api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from (
133143 lane_size as u64 * self . scaling * self . scaling ,
134144 ) ) ) ;
145+ let norm_var_tol_native = lane_size as u64 * self . scaling * self . scaling ;
135146 let norm_var_tolerance = api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from (
136- lane_size as u64 * self . scaling ,
147+ norm_var_tol_native ,
137148 ) ) ) ;
138- let norm_var_tol_bits = ( 2 * lane_size as u64 * self . scaling )
149+ let norm_var_tol_bits = ( 2 * norm_var_tol_native )
139150 . next_power_of_two ( )
140151 . trailing_zeros ( ) as usize ;
141152
142153 let mean_tolerance =
143154 api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from ( lane_size as u64 ) ) ) ;
144155 let mean_tol_bits = ( 2 * lane_size + 1 ) . next_power_of_two ( ) . trailing_zeros ( ) as usize ;
145156
146- let per_elem_tolerance = api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from ( 3u64 ) ) ) ;
147- let per_elem_tol_bits: usize = 3 ;
157+ // y_q comes from floating-point hint arithmetic while norm_q is reconstructed
158+ // from quantized intermediates; allow bounded slack at alpha^2 scale.
159+ let per_elem_tol_native = self . scaling * self . scaling ;
160+ let per_elem_tolerance = api. constant ( CircuitField :: < C > :: from_u256 ( U256 :: from (
161+ per_elem_tol_native,
162+ ) ) ) ;
163+ let per_elem_tol_bits = ( 2 * per_elem_tol_native)
164+ . next_power_of_two ( )
165+ . trailing_zeros ( ) as usize ;
148166
149167 let outer_size: usize = shape[ ..axis] . iter ( ) . product ( ) ;
150168 let flat_input: Vec < Variable > = x_input
@@ -192,13 +210,16 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
192210 let dev = api. sub ( flat_input[ start + i] , mean_q) ;
193211 let norm_unscaled = api. mul ( dev, inv_std_q) ;
194212
195- let norm_q = api. unconstrained_int_div ( norm_unscaled, scale_var) ;
196- let norm_rem = api. unconstrained_mod ( norm_unscaled, scale_var) ;
197- let recon = api. mul ( norm_q, scale_var) ;
213+ let norm_unscaled_biased = api. add ( norm_unscaled, norm_bias_var) ;
214+ let norm_q_biased = api. unconstrained_int_div ( norm_unscaled_biased, scale_var) ;
215+ let norm_rem = api. unconstrained_mod ( norm_unscaled_biased, scale_var) ;
216+ let recon = api. mul ( norm_q_biased, scale_var) ;
198217 let recon = api. add ( recon, norm_rem) ;
199- api. assert_is_equal ( recon, norm_unscaled ) ;
218+ api. assert_is_equal ( recon, norm_unscaled_biased ) ;
200219 logup_ctx. range_check :: < C , Builder > ( api, norm_rem, self . scale_exponent as usize ) ?;
201220
221+ let norm_q = api. sub ( norm_q_biased, norm_q_bias_var) ;
222+
202223 let norm_shifted = api. add ( norm_q, norm_offset) ;
203224 logup_ctx. range_check :: < C , Builder > ( api, norm_shifted, max_norm_bits + 1 ) ?;
204225
0 commit comments