Skip to content

Commit d908f2f

Browse files
committed
Fail fast on missing MODEL path, pin logup_chunk_bits in pipeline benches, add ConvTranspose zero-stride and bias ndim guards, and fix LayerNorm variance
tolerance overflow and zero-bits crash
1 parent 407c118 commit d908f2f

File tree

4 files changed

+80
-30
lines changed

4 files changed

+80
-30
lines changed

rust/jstprove_circuits/benches/autotuner_strategies.rs

Lines changed: 12 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -37,9 +37,12 @@ fn model_path() -> std::path::PathBuf {
3737
/// forcing a full sweep each time.
3838
fn bench_cold_compile(c: &mut Criterion) {
3939
let path = model_path();
40-
if !path.exists() {
41-
return;
42-
}
40+
assert!(
41+
path.exists(),
42+
"MODEL='{}' resolved to '{}' — ONNX file not found; set the MODEL env var to a model in jstprove_remainder/models/",
43+
model_name(),
44+
path.display()
45+
);
4346

4447
let metadata = expander_metadata::generate_from_onnx(&path).unwrap();
4548
// Force logup_chunk_bits = None so compile_bn254 always invokes the autotuner
@@ -80,9 +83,12 @@ fn bench_cold_compile(c: &mut Criterion) {
8083
/// every subsequent iteration hits the cache and skips the sweep.
8184
fn bench_warm_compile(c: &mut Criterion) {
8285
let path = model_path();
83-
if !path.exists() {
84-
return;
85-
}
86+
assert!(
87+
path.exists(),
88+
"MODEL='{}' resolved to '{}' — ONNX file not found; set the MODEL env var to a model in jstprove_remainder/models/",
89+
model_name(),
90+
path.display()
91+
);
8692

8793
// Pre-populate the cache with one cold compile.
8894
{

rust/jstprove_circuits/benches/pipeline_ecc.rs

Lines changed: 26 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -25,12 +25,18 @@ fn model_path() -> std::path::PathBuf {
2525

2626
fn bench_bn254(c: &mut Criterion) {
2727
let path = model_path();
28-
if !path.exists() {
29-
return;
30-
}
28+
assert!(
29+
path.exists(),
30+
"MODEL='{}' resolved to '{}' — ONNX file not found; set the MODEL env var to a model in jstprove_remainder/models/",
31+
model_name(),
32+
path.display()
33+
);
3134

3235
let metadata = expander_metadata::generate_from_onnx(&path).unwrap();
33-
let params = metadata.circuit_params.clone();
36+
// Pin a fixed chunk width so every compile iteration runs in the same warm-cache
37+
// regime without triggering the adaptive autotuner sweep on the first iteration.
38+
let mut params = metadata.circuit_params.clone();
39+
params.logup_chunk_bits = params.logup_chunk_bits.or(Some(12));
3440
let arch = metadata.architecture.clone();
3541
let wandb = metadata.wandb.clone();
3642
OnnxContext::set_all(arch.clone(), params.clone(), Some(wandb.clone()));
@@ -104,13 +110,17 @@ fn bench_bn254(c: &mut Criterion) {
104110

105111
fn bench_goldilocks(c: &mut Criterion) {
106112
let path = model_path();
107-
if !path.exists() {
108-
return;
109-
}
113+
assert!(
114+
path.exists(),
115+
"MODEL='{}' resolved to '{}' — ONNX file not found; set the MODEL env var to a model in jstprove_remainder/models/",
116+
model_name(),
117+
path.display()
118+
);
110119

111120
let metadata =
112121
expander_metadata::generate_from_onnx_for_field(&path, N_BITS_GOLDILOCKS, None).unwrap();
113-
let params = metadata.circuit_params.clone();
122+
let mut params = metadata.circuit_params.clone();
123+
params.logup_chunk_bits = params.logup_chunk_bits.or(Some(12));
114124
let arch = metadata.architecture.clone();
115125
let wandb = metadata.wandb.clone();
116126
OnnxContext::set_all(arch.clone(), params.clone(), Some(wandb.clone()));
@@ -184,13 +194,17 @@ fn bench_goldilocks(c: &mut Criterion) {
184194

185195
fn bench_goldilocks_basefold(c: &mut Criterion) {
186196
let path = model_path();
187-
if !path.exists() {
188-
return;
189-
}
197+
assert!(
198+
path.exists(),
199+
"MODEL='{}' resolved to '{}' — ONNX file not found; set the MODEL env var to a model in jstprove_remainder/models/",
200+
model_name(),
201+
path.display()
202+
);
190203

191204
let metadata =
192205
expander_metadata::generate_from_onnx_for_field(&path, N_BITS_GOLDILOCKS, None).unwrap();
193-
let params = metadata.circuit_params.clone();
206+
let mut params = metadata.circuit_params.clone();
207+
params.logup_chunk_bits = params.logup_chunk_bits.or(Some(12));
194208
let arch = metadata.architecture.clone();
195209
let wandb = metadata.wandb.clone();
196210
OnnxContext::set_all(arch.clone(), params.clone(), Some(wandb.clone()));

rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs

Lines changed: 19 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -181,6 +181,13 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConvTransposeLayer
181181
let mut res = if bias.is_empty() {
182182
ArrayD::from_elem(vec![n_batch, c_out, out_h, out_w], zero)
183183
} else {
184+
if bias.ndim() != 1 {
185+
return Err(LayerError::InvalidShape {
186+
layer: LayerKind::ConvTranspose,
187+
msg: format!("bias must be 1-D, got {}D", bias.ndim()),
188+
}
189+
.into());
190+
}
184191
if bias.shape()[0] != c_out {
185192
return Err(LayerError::InvalidShape {
186193
layer: LayerKind::ConvTranspose,
@@ -363,6 +370,18 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConvTransposeLayer
363370
}
364371
}
365372

373+
// Guard against zero strides: the apply method uses stride as a divisor.
374+
if strides.contains(&0) {
375+
return Err(LayerError::UnsupportedConfig {
376+
layer: LayerKind::ConvTranspose,
377+
msg: format!(
378+
"layer '{}': strides must all be > 0, got {:?}",
379+
layer.name, strides
380+
),
381+
}
382+
.into());
383+
}
384+
366385
let conv_transpose = Self {
367386
weights,
368387
bias,

rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs

Lines changed: 23 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -139,20 +139,25 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
139139
let norm_bias_var = api.constant(CircuitField::<C>::from_u256(norm_bias_u256));
140140
let norm_q_bias_u256 = U256::from(1u8) << (norm_bias_bits - self.scale_exponent as usize);
141141
let norm_q_bias_var = api.constant(CircuitField::<C>::from_u256(norm_q_bias_u256));
142-
let n_alpha_sq = api.constant(CircuitField::<C>::from_u256(U256::from(
143-
lane_size as u64 * self.scaling * self.scaling,
144-
)));
142+
// Compute N·α² in U256 to avoid u64 overflow when lane_size is large or
143+
// scale_exponent approaches the u64-squaring limit (guard in build caps at < 32,
144+
// but lane_size * 2^(2*exp) can still exceed u64 with exp=31 and lane_size≥4).
145+
let n_alpha_sq_u256 =
146+
U256::from(lane_size as u64) * U256::from(self.scaling) * U256::from(self.scaling);
147+
let n_alpha_sq = api.constant(CircuitField::<C>::from_u256(n_alpha_sq_u256));
145148
// Use half of n_alpha_sq as the variance tolerance so that the constraint
146149
// norm_sq_sum ∈ [N·α²/2, 3N·α²/2) is non-trivially enforced. With the
147150
// full N·α² tolerance the range-check reduced to just checking norm_sq_sum
148151
// is non-negative, which a cheating prover could trivially satisfy.
149-
let norm_var_tol_native = (lane_size as u64 * self.scaling * self.scaling) / 2;
150-
let norm_var_tolerance = api.constant(CircuitField::<C>::from_u256(U256::from(
151-
norm_var_tol_native,
152-
)));
153-
let norm_var_tol_bits = (2 * norm_var_tol_native)
154-
.next_power_of_two()
155-
.trailing_zeros() as usize;
152+
let norm_var_tol_u256 = n_alpha_sq_u256 >> 1u32; // / 2
153+
let norm_var_tolerance = api.constant(CircuitField::<C>::from_u256(norm_var_tol_u256));
154+
// ceil(log2(2 * norm_var_tol)) using U256; 0 is handled separately below.
155+
let norm_var_tol_bits: usize = if norm_var_tol_u256 == U256::from(0u64) {
156+
0
157+
} else {
158+
let two_tol = norm_var_tol_u256 * U256::from(2u64);
159+
(256u32 - (two_tol - U256::from(1u64)).leading_zeros()) as usize
160+
};
156161

157162
let mean_tolerance =
158163
api.constant(CircuitField::<C>::from_u256(U256::from(lane_size as u64)));
@@ -251,8 +256,14 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for LayerNormLayer {
251256
}
252257

253258
let var_diff = api.sub(norm_sq_sum, n_alpha_sq);
254-
let var_shifted = api.add(var_diff, norm_var_tolerance);
255-
logup_ctx.range_check::<C, Builder>(api, var_shifted, norm_var_tol_bits)?;
259+
if norm_var_tol_bits == 0 {
260+
// Tolerance is 0 (degenerate: scale_exponent = 0 and lane_size = 1);
261+
// range_check(_, 0) is an error, so assert equality directly.
262+
api.assert_is_equal(var_diff, zero_var);
263+
} else {
264+
let var_shifted = api.add(var_diff, norm_var_tolerance);
265+
logup_ctx.range_check::<C, Builder>(api, var_shifted, norm_var_tol_bits)?;
266+
}
256267

257268
flat_output.extend_from_slice(y_vars);
258269
}

0 commit comments

Comments (0)