Skip to content

Commit 665f405

Browse files
committed
Fix LayerNorm signed division/range constraints, improve LogUp out-of-bounds diagnostics, and add missing flatten_rmpv_to_f64 helper
1 parent 5b6bd69 commit 665f405

File tree

3 files changed

+85
-9
lines changed

3 files changed

+85
-9
lines changed

compiler/circuit-std-rs/src/logup.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,10 @@ pub fn query_count_hint<F: Field>(inputs: &[F], outputs: &mut [F]) -> Result<(),
469469
for input in inputs {
470470
let query_id = input.to_u256().as_usize();
471471
if query_id >= count.len() {
472-
return Err(Error::InternalError("query_id out of bounds".into()));
472+
return Err(Error::InternalError(format!(
473+
"query_id out of bounds: query_id={query_id}, table_len={}",
474+
count.len()
475+
)));
473476
}
474477
count[query_id] += 1;
475478
}

rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,40 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
129129
let norm_offset = api.constant(CircuitField::<C>::from_u256(U256::from(
130130
1u64 << max_norm_bits,
131131
)));
132+
// Signed-safe division support:
133+
// norm_unscaled can be negative, while unconstrained_int_div/mod operate
134+
// over field representatives. Bias by a large power-of-two multiple of
135+
// `scale` so the dividend is non-negative in the intended integer domain,
136+
// then subtract the corresponding quotient bias.
137+
let norm_bias_bits = max_norm_bits + self.scale_exponent as usize + 2;
138+
let norm_bias_u256 = U256::from(1u8) << norm_bias_bits;
139+
let norm_bias_var = api.constant(CircuitField::<C>::from_u256(norm_bias_u256));
140+
let norm_q_bias_u256 = U256::from(1u8) << (norm_bias_bits - self.scale_exponent as usize);
141+
let norm_q_bias_var = api.constant(CircuitField::<C>::from_u256(norm_q_bias_u256));
132142
let n_alpha_sq = api.constant(CircuitField::<C>::from_u256(U256::from(
133143
lane_size as u64 * self.scaling * self.scaling,
134144
)));
145+
let norm_var_tol_native = lane_size as u64 * self.scaling * self.scaling;
135146
let norm_var_tolerance = api.constant(CircuitField::<C>::from_u256(U256::from(
136-
lane_size as u64 * self.scaling,
147+
norm_var_tol_native,
137148
)));
138-
let norm_var_tol_bits = (2 * lane_size as u64 * self.scaling)
149+
let norm_var_tol_bits = (2 * norm_var_tol_native)
139150
.next_power_of_two()
140151
.trailing_zeros() as usize;
141152

142153
let mean_tolerance =
143154
api.constant(CircuitField::<C>::from_u256(U256::from(lane_size as u64)));
144155
let mean_tol_bits = (2 * lane_size + 1).next_power_of_two().trailing_zeros() as usize;
145156

146-
let per_elem_tolerance = api.constant(CircuitField::<C>::from_u256(U256::from(3u64)));
147-
let per_elem_tol_bits: usize = 3;
157+
// y_q comes from floating-point hint arithmetic while norm_q is reconstructed
158+
// from quantized intermediates; allow bounded slack at alpha^2 scale.
159+
let per_elem_tol_native = self.scaling * self.scaling;
160+
let per_elem_tolerance = api.constant(CircuitField::<C>::from_u256(U256::from(
161+
per_elem_tol_native,
162+
)));
163+
let per_elem_tol_bits = (2 * per_elem_tol_native)
164+
.next_power_of_two()
165+
.trailing_zeros() as usize;
148166

149167
let outer_size: usize = shape[..axis].iter().product();
150168
let flat_input: Vec<Variable> = x_input
@@ -192,13 +210,16 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
192210
let dev = api.sub(flat_input[start + i], mean_q);
193211
let norm_unscaled = api.mul(dev, inv_std_q);
194212

195-
let norm_q = api.unconstrained_int_div(norm_unscaled, scale_var);
196-
let norm_rem = api.unconstrained_mod(norm_unscaled, scale_var);
197-
let recon = api.mul(norm_q, scale_var);
213+
let norm_unscaled_biased = api.add(norm_unscaled, norm_bias_var);
214+
let norm_q_biased = api.unconstrained_int_div(norm_unscaled_biased, scale_var);
215+
let norm_rem = api.unconstrained_mod(norm_unscaled_biased, scale_var);
216+
let recon = api.mul(norm_q_biased, scale_var);
198217
let recon = api.add(recon, norm_rem);
199-
api.assert_is_equal(recon, norm_unscaled);
218+
api.assert_is_equal(recon, norm_unscaled_biased);
200219
logup_ctx.range_check::<C, Builder>(api, norm_rem, self.scale_exponent as usize)?;
201220

221+
let norm_q = api.sub(norm_q_biased, norm_q_bias_var);
222+
202223
let norm_shifted = api.add(norm_q, norm_offset);
203224
logup_ctx.range_check::<C, Builder>(api, norm_shifted, max_norm_bits + 1)?;
204225

rust/jstprove_circuits/src/runner/main_runner.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,58 @@ fn natural_key_cmp(a: &str, b: &str) -> std::cmp::Ordering {
17901790
.then_with(|| a.cmp(b))
17911791
}
17921792

1793+
fn flatten_rmpv_to_f64(val: &Value, out: &mut Vec<f64>) -> Result<(), RunError> {
1794+
match val {
1795+
Value::Array(arr) => {
1796+
for item in arr {
1797+
flatten_rmpv_to_f64(item, out)?;
1798+
}
1799+
Ok(())
1800+
}
1801+
Value::Map(entries) => {
1802+
let mut pairs: Vec<_> = entries
1803+
.iter()
1804+
.filter_map(|(k, v)| k.as_str().map(|s| (s.to_string(), v)))
1805+
.collect();
1806+
pairs.sort_by(|(a, _), (b, _)| natural_key_cmp(a, b));
1807+
for (_, v) in pairs {
1808+
flatten_rmpv_to_f64(v, out)?;
1809+
}
1810+
Ok(())
1811+
}
1812+
Value::Integer(n) => {
1813+
let i = n
1814+
.as_i64()
1815+
.ok_or_else(|| RunError::Deserialize("integer value out of i64 range".into()))?;
1816+
out.push(i as f64);
1817+
Ok(())
1818+
}
1819+
Value::F64(f) => {
1820+
if f.is_finite() {
1821+
out.push(*f);
1822+
Ok(())
1823+
} else {
1824+
Err(RunError::Deserialize(
1825+
"non-finite f64 value in input".into(),
1826+
))
1827+
}
1828+
}
1829+
Value::F32(f) => {
1830+
if f.is_finite() {
1831+
out.push(f64::from(*f));
1832+
Ok(())
1833+
} else {
1834+
Err(RunError::Deserialize(
1835+
"non-finite f32 value in input".into(),
1836+
))
1837+
}
1838+
}
1839+
other => Err(RunError::Deserialize(format!(
1840+
"unsupported input value type for debug assignment: {other:?}"
1841+
))),
1842+
}
1843+
}
1844+
17931845
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
17941846
fn flatten_value_to_i64(val: &Value) -> Vec<i64> {
17951847
match val {

0 commit comments

Comments
 (0)