Skip to content

Commit bda679d

Browse files
committed
Fix benchmark labels, remove invalid AveragePool dilations attribute, add pow NaN/Inf fallback tests, and harden Reshape dimension validation in
propagate_shapes
1 parent 665f405 commit bda679d

File tree

4 files changed

+94
-42
lines changed

4 files changed

+94
-42
lines changed

rust/jstprove_circuits/benches/autotuner_strategies.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ fn clear_cache() {
2323
}
2424
}
2525

26+
fn model_name() -> String {
27+
std::env::var("MODEL").unwrap_or_else(|_| "lenet".to_string())
28+
}
29+
2630
fn model_path() -> std::path::PathBuf {
27-
let model_name = std::env::var("MODEL").unwrap_or_else(|_| "lenet".to_string());
2831
Path::new(env!("CARGO_MANIFEST_DIR"))
2932
.join("../jstprove_remainder/models")
30-
.join(format!("{model_name}.onnx"))
33+
.join(format!("{}.onnx", model_name()))
3134
}
3235

3336
/// Benchmark cold compile: the autotuner cache is cleared before every iteration,
@@ -47,7 +50,7 @@ fn bench_cold_compile(c: &mut Criterion) {
4750
group.sample_size(10);
4851
group.measurement_time(Duration::from_secs(120));
4952

50-
group.bench_function("lenet", |b| {
53+
group.bench_function(model_name(), |b| {
5154
b.iter_batched(
5255
|| {
5356
clear_cache();
@@ -103,7 +106,7 @@ fn bench_warm_compile(c: &mut Criterion) {
103106
group.sample_size(10);
104107
group.measurement_time(Duration::from_secs(120));
105108

106-
group.bench_function("lenet", |b| {
109+
group.bench_function(model_name(), |b| {
107110
b.iter_batched(
108111
|| {
109112
OnnxContext::set_all(arch.clone(), params.clone(), Some(wandb.clone()));

rust/jstprove_circuits/benches/layer_bench.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,10 +758,6 @@ fn make_averagepool_metadata() -> (CircuitParams, Architecture, WANDB) {
758758
Value::String("pads".into()),
759759
Value::Array(vec![Value::from(0i64); 4]),
760760
),
761-
(
762-
Value::String("dilations".into()),
763-
Value::Array(vec![Value::from(1i64), Value::from(1i64)]),
764-
),
765761
])),
766762
opset_version_number: 17,
767763
};

rust/jstprove_circuits/src/circuit_functions/hints/pow.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,32 @@ mod tests {
174174
"(-2)^3: got {result}, expected {expected}"
175175
);
176176
}
177+
178+
#[test]
179+
fn pow_zero_zero() {
180+
// 0^0: 0.0_f64.powf(0.0) returns 1.0 per IEEE 754 (not NaN), so the result is 1.0 (scale_u64 as i64); the NaN fallback is not exercised here
181+
let scale: u64 = 1 << 18;
182+
let x_q = 0i64;
183+
let exp_q = 0i64;
184+
let result = run_hint(x_q, exp_q, scale);
185+
assert_eq!(
186+
result, scale as i64,
187+
"0^0: expected 1.0 encoded as {}, got {result}",
188+
scale
189+
);
190+
}
191+
192+
#[test]
193+
fn pow_negative_fractional_exponent() {
194+
// (-2)^0.5 → (-2.0_f64).powf(0.5) = NaN → fallback to 1.0 (scale_u64 as i64)
195+
let scale: u64 = 1 << 18;
196+
let x_q = -2 * scale as i64;
197+
let exp_q = scale as i64 / 2; // 0.5 in fixed-point
198+
let result = run_hint(x_q, exp_q, scale);
199+
assert_eq!(
200+
result, scale as i64,
201+
"(-2)^0.5: expected 1.0 encoded as {}, got {result}",
202+
scale
203+
);
204+
}
177205
}

rust/jstprove_onnx/src/quantizer.rs

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ fn compute_max_bound_inner(graph: &LayerGraph) -> Result<f64> {
329329
// default alpha for Log-layer bound propagation (ln(2^18) ≈ 12.47).
330330
let default_alpha_f64 = (DEFAULT_SCALE_BASE as f64).powi(DEFAULT_SCALE_EXPONENT as i32);
331331

332-
let shapes = propagate_shapes(graph);
332+
let shapes = propagate_shapes(graph)?;
333333

334334
for name in &graph.input_names {
335335
bounds.insert(name.clone(), 1.0);
@@ -413,7 +413,7 @@ fn broadcast_two(a: &[usize], b: &[usize]) -> Vec<usize> {
413413
///
414414
/// This is used by `compute_bounds` so that reduction ops (e.g. ReduceSum)
415415
/// can account for the actual number of elements being accumulated.
416-
fn propagate_shapes(graph: &LayerGraph) -> HashMap<String, Vec<usize>> {
416+
fn propagate_shapes(graph: &LayerGraph) -> Result<HashMap<String, Vec<usize>>> {
417417
let mut shapes: HashMap<String, Vec<usize>> = HashMap::new();
418418

419419
// Seed with model-level input shapes.
@@ -545,43 +545,68 @@ fn propagate_shapes(graph: &LayerGraph) -> HashMap<String, Vec<usize>> {
545545
let in_shape: &[usize] = input_shape.as_deref().unwrap_or(&[]);
546546
let input_total: usize = in_shape.iter().product();
547547

548-
// First pass: resolve 0 → copy and positives; mark -1 as None.
549-
let mut dims: Vec<Option<usize>> = raw
550-
.iter()
551-
.enumerate()
552-
.map(|(i, &d)| {
553-
if d == 0 {
554-
if allowzero {
555-
Some(0)
556-
} else {
557-
Some(in_shape.get(i).copied().unwrap_or(0))
558-
}
559-
} else if d > 0 {
560-
Some(d as usize)
548+
// First pass: resolve 0 → copy and positives; -1 → infer; other negatives are invalid.
549+
let mut dims: Vec<Option<usize>> = Vec::with_capacity(raw.len());
550+
for (i, &d) in raw.iter().enumerate() {
551+
if d == 0 {
552+
dims.push(if allowzero {
553+
Some(0)
561554
} else {
562-
None // -1 only
563-
}
564-
})
565-
.collect();
555+
Some(in_shape.get(i).copied().unwrap_or(0))
556+
});
557+
} else if d > 0 {
558+
dims.push(Some(d as usize));
559+
} else if d == -1 {
560+
dims.push(None); // infer sentinel
561+
} else {
562+
anyhow::bail!(
563+
"Reshape layer '{}': invalid dimension {} at index {} (only -1 is a valid infer sentinel)",
564+
layer.name, d, i
565+
);
566+
}
567+
}
566568

567569
// Infer the single -1 dimension when input total is known.
568-
if input_total > 0 {
569-
let n_unknown = dims.iter().filter(|d| d.is_none()).count();
570-
if n_unknown == 1 {
571-
let known: usize = dims.iter().filter_map(|&d| d).product();
572-
if known > 0 {
573-
let inferred = input_total / known;
574-
for d in &mut dims {
575-
if d.is_none() {
576-
*d = Some(inferred);
577-
break;
578-
}
579-
}
570+
let n_unknown = dims.iter().filter(|d| d.is_none()).count();
571+
if n_unknown == 1 && input_total > 0 {
572+
let known: usize = dims.iter().filter_map(|&d| d).product();
573+
if known == 0 {
574+
anyhow::bail!(
575+
"Reshape layer '{}': cannot infer -1 dimension when known product is 0",
576+
layer.name
577+
);
578+
}
579+
if input_total % known != 0 {
580+
anyhow::bail!(
581+
"Reshape layer '{}': input total {} is not divisible by known dims product {}",
582+
layer.name, input_total, known
583+
);
584+
}
585+
let inferred = input_total / known;
586+
for d in &mut dims {
587+
if d.is_none() {
588+
*d = Some(inferred);
589+
break;
580590
}
581591
}
592+
} else if n_unknown > 1 {
593+
anyhow::bail!(
594+
"Reshape layer '{}': more than one -1 dimension is not allowed",
595+
layer.name
596+
);
582597
}
583598

584-
dims.into_iter().map(|d| d.unwrap_or(0)).collect()
599+
// All None should be resolved by now; any remaining None is an error.
600+
dims.into_iter()
601+
.map(|d| {
602+
d.ok_or_else(|| {
603+
anyhow::anyhow!(
604+
"Reshape layer '{}': could not infer -1 dimension (input shape unknown)",
605+
layer.name
606+
)
607+
})
608+
})
609+
.collect::<Result<Vec<usize>>>()?
585610
} else {
586611
input_shape.unwrap_or_default()
587612
}
@@ -1145,7 +1170,7 @@ fn propagate_shapes(graph: &LayerGraph) -> HashMap<String, Vec<usize>> {
11451170
}
11461171
}
11471172

1148-
shapes
1173+
Ok(shapes)
11491174
}
11501175

11511176
fn compute_bounds(graph: &LayerGraph, config: &ScaleConfig) -> Result<HashMap<String, usize>> {
@@ -1154,7 +1179,7 @@ fn compute_bounds(graph: &LayerGraph, config: &ScaleConfig) -> Result<HashMap<St
11541179
let mut n_bits_config = HashMap::new();
11551180

11561181
// Pre-compute tensor shapes for reduction-size estimation.
1157-
let shapes = propagate_shapes(graph);
1182+
let shapes = propagate_shapes(graph)?;
11581183

11591184
for name in &graph.input_names {
11601185
bounds.insert(name.clone(), 1.0);

0 commit comments

Comments (0)