JSTPRV-164 per layer benchmarks#252

Open
millioner wants to merge 18 commits into main from
JSTPRV-164_per_layer_benchmarks

Conversation

@millioner
Contributor

@millioner millioner commented Apr 7, 2026

Description

Related Issue

Type of Change

  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix/feature causing existing functionality to break)
  • Refactor (non-functional changes)
  • Documentation update

Checklist

  • Code follows project patterns
  • Tests added/updated (if applicable)
  • Documentation updated (if applicable)
  • Self-review of code
  • All tests pass locally
  • Linter passes locally

Deployment Notes

Additional Comments

Summary by CodeRabbit

  • Bug Fixes
    • More informative out-of-bounds query errors showing offending ID and table size.
    • Improved input deserialization with clearer errors for unsupported/invalid msgpack shapes.
    • Numeric edge-case handling hardened: pow() falls back to 1.0 on NaN/Inf, sqrt enforces non-negativity before hinting, LayerNorm tolerances/overflow checks tightened.
  • New Features
    • ConvTranspose layer implemented.
    • Added a witness-generation path that accepts pre-quantized integer inputs.
  • Refactor
    • Benchmark suite migrated to Criterion with new layer-focused benchmarks and a new benchmark target.
  • Other
    • Shape propagation and reshape inference made stricter, erroring on ambiguous/invalid shapes.

@coderabbitai

coderabbitai bot commented Apr 7, 2026

Walkthrough

Converts custom benchmarks to Criterion and adds layer/compile benchmarks; implements ConvTranspose; adds a two-pass pre‑quantized witness API; tightens numeric/shape/range handling across hints and layers; refactors reshape propagation; and improves LogUp error messaging.

Changes

Cohort / File(s) Summary
Benchmark infra & entries
rust/jstprove_circuits/Cargo.toml, rust/jstprove_circuits/benches/...
Replaced ad-hoc mains with Criterion harnesses, added layer_bench bench target, standardized sample/measurement settings, removed custom timing/reporting.
Per-layer benchmarks
rust/jstprove_circuits/benches/layer_bench.rs
Added Criterion benchmarks for six single-layer workloads (compile/witness/prove) using in-memory layer metadata and deterministic constant weights.
Autotuner & compile benches
rust/jstprove_circuits/benches/autotuner_strategies.rs, .../logup_chunking.rs, .../pipeline_ecc.rs
Converted to Criterion, simplified to compile/witness/prove flows, added cache-clearing/temp-dir batching and model-path assertions, removed multi-stage printed pipeline logic.
ConvTranspose layer
rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
Implemented ConvTransposeLayer (struct, build, apply): weight/bias handling, shape validation, output-size computation, inverse-mapping accumulation, rescale support, and parameter validation.
Numeric/hint & layer fixes
rust/jstprove_circuits/src/circuit_functions/hints/pow.rs, .../layers/layer_norm.rs, .../layers/sqrt.rs
pow now falls back to quantized 1.0 on NaN/Inf; LayerNorm uses U256-based tolerance and biased division with explicit branches; Sqrt adds pre-hint non-negativity range check.
Witness API & runner
rust/jstprove_circuits/src/onnx.rs, rust/jstprove_circuits/src/runner/main_runner.rs
Added witness_from_prequantized (two-pass probe → evaluate → final solve), runner now uses it; removed f64 flattening path and strengthened rmpv::Value flattening/errors.
Quantizer / shape propagation
rust/jstprove_onnx/src/quantizer.rs
propagate_shapes now returns Result; Reshape propagation refactored to Option dims with stricter -1 inference, divisibility checks, and explicit errors for invalid/multiple unknowns.
LogUp error message
compiler/circuit-std-rs/src/logup.rs
query_count_hint now returns an Error::InternalError with a formatted message including the offending query_id and count.len().
Minor tests/comments
rust/jstprove_circuits/src/circuit_functions/layers/matmul.rs
Added clarifying comment in a unit test documenting Freivalds cost inequality for ell=1.
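The LogUp change in the table above amounts to an error-message upgrade: an out-of-bounds query now reports both the offending id and the table size. A minimal stand-alone sketch of the shape of that check (function signature and `String` error simplified from the real `Error::InternalError`):

```rust
// Hypothetical simplified form of the improved out-of-bounds error:
// include the offending query id and the table size in the message
// instead of a bare index failure.
fn query_count_hint(count: &[u64], query_id: usize) -> Result<u64, String> {
    count.get(query_id).copied().ok_or_else(|| {
        format!(
            "LogUp query id {query_id} out of bounds for table of size {}",
            count.len()
        )
    })
}

fn main() {
    let table = vec![1u64, 2, 3];
    assert_eq!(query_count_hint(&table, 1), Ok(2));
    let err = query_count_hint(&table, 7).unwrap_err();
    assert!(err.contains("7") && err.contains("3"));
    println!("ok");
}
```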

Sequence Diagram(s)

sequenceDiagram
  participant Caller as Caller
  participant Loader as CircuitLoader
  participant Probe as ProbeSolver
  participant Eval as LayeredEvaluator
  participant Final as FinalSolver
  participant Serializer as WitnessSerializer

  Caller->>Loader: witness_from_prequantized(circuit_bytes, solver_bytes, params, input_data)
  Loader->>Probe: apply_input_data (outputs zeroed) + solve_witness (probe)
  Probe-->>Loader: probe assignment (probe witness bytes)
  Loader->>Eval: evaluate layered circuit with probe inputs
  Eval-->>Loader: computed outputs (field elements)
  Loader->>Final: set computed outputs on probe assignment + solve_witness (final)
  Final-->>Serializer: final witness bytes
  Serializer-->>Caller: return WitnessBundle (serialized, output_data, version)
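The two-pass flow in the diagram can be illustrated with a toy sketch in which the solver and evaluator are stubbed out as plain functions (all types and names here are hypothetical, and the "circuit" is just y = x * x): the probe pass solves with outputs zeroed, the evaluator derives the true outputs from the probe inputs, and the final pass re-solves with those outputs filled in.

```rust
// Toy two-pass witness generation (all types hypothetical).
#[derive(Clone)]
struct Assignment {
    inputs: Vec<u64>,
    outputs: Vec<u64>,
}

// A real solver would derive all intermediate wires; this stub just
// concatenates inputs and outputs into a flat "witness".
fn solve_witness(a: &Assignment) -> Vec<u64> {
    a.inputs.iter().chain(a.outputs.iter()).copied().collect()
}

// Stand-in for the layered-circuit evaluator: y = x * x per input.
fn evaluate(inputs: &[u64]) -> Vec<u64> {
    inputs.iter().map(|x| x * x).collect()
}

fn witness_from_prequantized(inputs: Vec<u64>) -> Vec<u64> {
    // Pass 1: probe solve with the output slots zeroed.
    let probe = Assignment { outputs: vec![0; inputs.len()], inputs };
    let _probe_witness = solve_witness(&probe);
    // Evaluate the circuit on the probe inputs to get the real outputs.
    let outputs = evaluate(&probe.inputs);
    // Pass 2: final solve with the computed outputs in place.
    let fin = Assignment { outputs, ..probe };
    solve_witness(&fin)
}

fn main() {
    assert_eq!(witness_from_prequantized(vec![2, 3]), vec![2, 3, 4, 9]);
    println!("ok");
}
```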

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • tmfreiberg
  • jsgold-1

Poem

"I nibble bytes beneath the moonlit test,
Benches hum soft as Criterion does rest,
ConvTranspose flips weights with nimble art,
Two-pass witnesses give the solver heart,
A rabbit hops — compile, witness, proof, and zest!" 🐇

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed (check skipped; CodeRabbit's high-level summary is enabled)
  • Title Check: ✅ Passed (the PR title 'JSTPRV-164 per layer benchmarks' directly and clearly summarizes the main change: adding per-layer benchmarking functionality)
  • Docstring Coverage: ✅ Passed (docstring coverage is 100.00%, above the required threshold of 80.00%)


@millioner millioner marked this pull request as ready for review on April 7, 2026 at 18:08

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
rust/jstprove_onnx/src/quantizer.rs (1)

548-585: ⚠️ Potential issue | 🟠 Major

Reshape shape parsing now accepts invalid negatives and can infer invalid dimensions.

At Line 562, every d < 0 is treated as infer (None), but only -1 is valid for ONNX Reshape. Also, Line 573 infers with integer division and no divisibility check, and Line 584 silently maps unresolved unknowns to 0. This can produce impossible shapes and skew downstream bound/n_bits estimation.


Suggested guardrails
-                        let mut dims: Vec<Option<usize>> = raw
+                        // Reject invalid negative dims (only -1 is legal in ONNX Reshape).
+                        if raw.iter().any(|&d| d < -1) {
+                            return input_shape.unwrap_or_default();
+                        }
+
+                        let mut dims: Vec<Option<usize>> = raw
                             .iter()
                             .enumerate()
                             .map(|(i, &d)| {
                                 if d == 0 {
                                     if allowzero {
                                         Some(0)
                                     } else {
                                         Some(in_shape.get(i).copied().unwrap_or(0))
                                     }
                                 } else if d > 0 {
                                     Some(d as usize)
                                 } else {
-                                    None // -1 only
+                                    None // -1
                                 }
                             })
                             .collect();

                         // Infer the single -1 dimension when input total is known.
                         if input_total > 0 {
                             let n_unknown = dims.iter().filter(|d| d.is_none()).count();
                             if n_unknown == 1 {
                                 let known: usize = dims.iter().filter_map(|&d| d).product();
                                 if known > 0 {
-                                    let inferred = input_total / known;
+                                    if input_total % known != 0 {
+                                        return input_shape.unwrap_or_default();
+                                    }
+                                    let inferred = input_total / known;
                                     for d in &mut dims {
                                         if d.is_none() {
                                             *d = Some(inferred);
                                             break;
                                         }
                                     }
                                 }
                             }
                         }

-                        dims.into_iter().map(|d| d.unwrap_or(0)).collect()
+                        // Avoid silently manufacturing zeros for unresolved unknown dims.
+                        if dims.iter().any(|d| d.is_none()) {
+                            input_shape.unwrap_or_default()
+                        } else {
+                            dims.into_iter().flatten().collect()
+                        }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_onnx/src/quantizer.rs` around lines 548 - 585, The reshape
parsing currently treats any negative `d` in `raw` as an infer (`None`), uses
integer division to infer the unknown dimension, and finally maps unresolved
`None` to 0, producing invalid shapes; instead, in the block handling `raw` →
`dims` (variables: raw, dims, in_shape, allowzero, input_total) only treat `-1`
as the valid infer sentinel and treat any other negative `d` as an error; when
n_unknown == 1, check divisibility (input_total % known == 0) before computing
inferred = input_total / known and return an error if not divisible; do not
silently convert remaining None to 0—propagate/return an error (e.g.,
Result::Err) so callers handle invalid reshape specs rather than producing
impossible shapes.
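The guardrails above boil down to three rules: only -1 may be inferred, at most one -1 is allowed, and the known product must divide the input element count exactly. A stand-alone sketch of that inference (simplified: it omits ONNX's `0` = "copy input dim" / `allowzero` handling shown in the diff):

```rust
// Hypothetical simplified Reshape inference with the strict checks
// described above: reject negatives other than -1, reject multiple
// unknowns, and require exact divisibility before inferring.
fn infer_reshape(input_total: usize, target: &[i64]) -> Result<Vec<usize>, String> {
    if target.iter().any(|&d| d < -1) {
        return Err("only -1 is a legal negative dim in ONNX Reshape".into());
    }
    let n_unknown = target.iter().filter(|&&d| d == -1).count();
    if n_unknown > 1 {
        return Err("at most one -1 dim may be inferred".into());
    }
    let known: usize = target.iter().filter(|&&d| d != -1).map(|&d| d as usize).product();
    if n_unknown == 1 {
        if known == 0 || input_total % known != 0 {
            return Err(format!("{input_total} elements do not divide by {known}"));
        }
        Ok(target
            .iter()
            .map(|&d| if d == -1 { input_total / known } else { d as usize })
            .collect())
    } else {
        if known != input_total {
            return Err(format!("shape product {known} != input total {input_total}"));
        }
        Ok(target.iter().map(|&d| d as usize).collect())
    }
}

fn main() {
    assert_eq!(infer_reshape(12, &[3, -1]), Ok(vec![3, 4]));
    assert!(infer_reshape(12, &[5, -1]).is_err()); // 12 % 5 != 0
    assert!(infer_reshape(12, &[-1, -1]).is_err()); // multiple unknowns
    assert!(infer_reshape(12, &[3, -2]).is_err()); // invalid negative
    println!("ok");
}
```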
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@rust/jstprove_circuits/benches/autotuner_strategies.rs`:
- Line 50: The benchmark ID is hardcoded as "lenet" inside the
group.bench_function call, causing incorrect labels when the MODEL constant
differs; update both bench_cold_compile and bench_warm_compile to use the
selected model name (e.g., MODEL or a formatted string like
format!("{}",&MODEL)) when calling group.bench_function so the benchmark key
reflects the actual model being tested.

In `@rust/jstprove_circuits/benches/layer_bench.rs`:
- Around line 761-764: The AveragePool metadata currently includes a "dilations"
attribute tuple (Value::String("dilations"), Value::Array(...)) which is invalid
for ONNX opset 17; remove that tuple from the metadata where AveragePool is
constructed so the averagepool_* benchmarks no longer emit the dilations
attribute and compile_bn254() will pass strict attribute validation.

In `@rust/jstprove_circuits/src/circuit_functions/hints/pow.rs`:
- Around line 74-78: Add unit tests that exercise the NaN/Inf fallback path by
calling the pow hint with inputs that produce y_scaled.is_nan() and
y_scaled.is_infinite() (e.g., 0^0 and a negative base with a fractional
exponent) and assert the hint returns the fixed-point identity value (scale_u64
as i64). Locate the pow logic that checks y_scaled.is_nan() ||
y_scaled.is_infinite() (reference y_scaled and scale_u64) and write tests named
pow_zero_zero() and pow_negative_fractional_exponent() that construct the same
inputs the hint expects, invoke the hint function, and assert the output equals
the encoded 1.0 value so the intended silent fallback is covered.

In `@rust/jstprove_circuits/src/runner/main_runner.rs`:
- Around line 1049-1055: The new call to witness_from_prequantized::<C> skips
alpha scaling and so accepts already-quantized inputs, but flatten_recursive and
the witness assignment still provide raw f64s (params.scale_base and
params.scale_exponent are computed but never applied), causing incorrect
witnesses; fix by restoring quantization in this path: compute alpha from
params.scale_base/scale_exponent (as the old witness_from_f64_generic did) and
multiply each extracted activation in flatten_recursive (or immediately before
calling witness_from_prequantized) by alpha so the witness assignment receives
pre-quantized integers, or alternatively revert this call back to
witness_from_f64_generic if you intend to accept raw f64 inputs. Ensure you
update the codepaths referencing witness_from_prequantized, flatten_recursive,
and the witness assignment to perform the same alpha multiplication used
previously.

---

Outside diff comments:
In `@rust/jstprove_onnx/src/quantizer.rs`:
- Around line 548-585: The reshape parsing currently treats any negative `d` in
`raw` as an infer (`None`), uses integer division to infer the unknown
dimension, and finally maps unresolved `None` to 0, producing invalid shapes;
instead, in the block handling `raw` → `dims` (variables: raw, dims, in_shape,
allowzero, input_total) only treat `-1` as the valid infer sentinel and treat
any other negative `d` as an error; when n_unknown == 1, check divisibility
(input_total % known == 0) before computing inferred = input_total / known and
return an error if not divisible; do not silently convert remaining None to
0—propagate/return an error (e.g., Result::Err) so callers handle invalid
reshape specs rather than producing impossible shapes.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: a296c79e-ff6c-4dae-adcd-4753f2e78047

📥 Commits

Reviewing files that changed from the base of the PR and between 7f100e6 and 665f405.

📒 Files selected for processing (14)
  • compiler/circuit-std-rs/src/logup.rs
  • rust/jstprove_circuits/Cargo.toml
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/layer_bench.rs
  • rust/jstprove_circuits/benches/logup_chunking.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/matmul.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/sqrt.rs
  • rust/jstprove_circuits/src/onnx.rs
  • rust/jstprove_circuits/src/runner/main_runner.rs
  • rust/jstprove_onnx/src/quantizer.rs


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
rust/jstprove_circuits/benches/layer_bench.rs (2)

163-236: Consider adding measurement_time to witness and prove benchmarks for consistency.

The compile benchmarks explicitly set measurement_time(Duration::from_secs(120)), but witness and prove benchmarks only set sample_size(10). If witness/prove durations vary significantly or are unexpectedly slow, Criterion may not collect enough samples for stable measurements. Consider either:

  1. Adding explicit measurement_time to witness/prove groups, or
  2. Adding a comment explaining why compile needs it but witness/prove do not.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/layer_bench.rs` around lines 163 - 236, The
conv_witness and conv_prove benchmark groups omit measurement_time while
conv_compile sets measurement_time(Duration::from_secs(120)), so update the
benchmark setup in the functions conv_witness and conv_prove to match by calling
group.measurement_time(Duration::from_secs(120)) after creating the group (or
alternatively add a brief comment in each function explaining why a longer
measurement_time isn’t needed); locate the group variable in conv_witness and
conv_prove and add the same measurement_time call used in conv_compile for
consistency.

169-177: Note: Global state via OnnxContext::set_all may require sequential benchmark execution.

OnnxContext::set_all() modifies global state before each benchmark iteration. Criterion runs benchmarks sequentially by default, so this should be fine. However, if parallel benchmark execution is enabled in the future (e.g., via configuration or external runners), the global state mutations could cause race conditions or incorrect measurements.

Consider adding a comment near the benchmark registration or in the module docs noting that these benchmarks assume sequential execution.
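Beyond a comment, one way to harden this would be to guard the shared configuration behind a `Mutex`, so any future parallel runners serialize instead of racing. A generic sketch using only `std` (`OnnxContext` is project code; the `BenchContext` stand-in and its fields are hypothetical):

```rust
// Generic sketch of guarding shared benchmark configuration behind a
// Mutex. BenchContext stands in for the project's OnnxContext.
use std::sync::{Mutex, OnceLock};

#[derive(Default)]
struct BenchContext {
    model: String,
    chunk_bits: Option<u32>,
}

fn context() -> &'static Mutex<BenchContext> {
    static CTX: OnceLock<Mutex<BenchContext>> = OnceLock::new();
    CTX.get_or_init(|| Mutex::new(BenchContext::default()))
}

// Each benchmark setup takes the lock while mutating shared state, so
// concurrent callers serialize rather than observing torn updates.
fn set_all(model: &str, chunk_bits: Option<u32>) {
    let mut ctx = context().lock().unwrap();
    ctx.model = model.to_string();
    ctx.chunk_bits = chunk_bits;
}

fn main() {
    set_all("conv_layer", None);
    assert_eq!(context().lock().unwrap().model, "conv_layer");
    set_all("matmul_layer", Some(8));
    assert_eq!(context().lock().unwrap().chunk_bits, Some(8));
    println!("ok");
}
```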

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/layer_bench.rs` around lines 169 - 177,
OnnxContext::set_all mutates global state before each iteration (used inside
group.bench_function for bn254 and compile_bn254), so add a clear comment next
to the benchmark registration (or at the top of the module) stating these
benches assume sequential execution and that enabling parallel criterion
execution or external parallel runners may cause races; mention the exact
symbols OnnxContext::set_all, group.bench_function, and compile_bn254 so future
maintainers know why the assumption exists and consider protecting the global
state (e.g., via a mutex) if parallel runs are later enabled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@rust/jstprove_circuits/benches/layer_bench.rs`:
- Around line 163-236: The conv_witness and conv_prove benchmark groups omit
measurement_time while conv_compile sets
measurement_time(Duration::from_secs(120)), so update the benchmark setup in the
functions conv_witness and conv_prove to match by calling
group.measurement_time(Duration::from_secs(120)) after creating the group (or
alternatively add a brief comment in each function explaining why a longer
measurement_time isn’t needed); locate the group variable in conv_witness and
conv_prove and add the same measurement_time call used in conv_compile for
consistency.
- Around line 169-177: OnnxContext::set_all mutates global state before each
iteration (used inside group.bench_function for bn254 and compile_bn254), so add
a clear comment next to the benchmark registration (or at the top of the module)
stating these benches assume sequential execution and that enabling parallel
criterion execution or external parallel runners may cause races; mention the
exact symbols OnnxContext::set_all, group.bench_function, and compile_bn254 so
future maintainers know why the assumption exists and consider protecting the
global state (e.g., via a mutex) if parallel runs are later enabled.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: ba6e218d-914a-4d2f-a6f5-5b8767c341d2

📥 Commits

Reviewing files that changed from the base of the PR and between 665f405 and bda679d.

📒 Files selected for processing (4)
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/layer_bench.rs
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_onnx/src/quantizer.rs
🚧 Files skipped from review as they are similar to previous changes (2)
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/benches/autotuner_strategies.rs

@millioner millioner force-pushed the JSTPRV-164_per_layer_benchmarks branch from bda679d to c035960 on April 8, 2026 at 15:44

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

♻️ Duplicate comments (2)
rust/jstprove_circuits/benches/pipeline_ecc.rs (1)

42-46: ⚠️ Potential issue | 🟡 Minor

Add MODEL to the pipeline/... benchmark paths.

These groups and stage names are static, so changing MODEL still produces the same pipeline/.../compile|witness|prove IDs. Criterion will compare different models as if they were the same benchmark. Put the selected model name in either the group or function ID.

Also applies to: 122-126, 202-206

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/pipeline_ecc.rs` around lines 42 - 46, The
benchmark group name "pipeline/bn254" and the bench_function IDs ("compile",
"witness", "prove") are static so different MODEL values are being conflated;
include the MODEL identifier in the Criterion IDs by appending or interpolating
the MODEL value into the group name or the per-stage function name where group
is constructed (the call to c.benchmark_group("pipeline/bn254")) and where
group.bench_function("compile" / "witness" / "prove") is invoked (also at the
other similar sites around lines 122-126 and 202-206) so each model produces
unique benchmark IDs like format!("pipeline/{}/compile", MODEL) or
format!("compile/{}", MODEL).
rust/jstprove_circuits/benches/logup_chunking.rs (1)

25-40: ⚠️ Potential issue | 🟡 Minor

Include MODEL in the Criterion ID.

Right now only the chunk-width label distinguishes these results. Re-running the same bench binary with a different MODEL will reuse the same logup_chunking/... report path and compare unrelated models.

Proposed fix
-    let mut group = c.benchmark_group("logup_chunking");
+    let mut group = c.benchmark_group(format!("logup_chunking/{model_name}"));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/logup_chunking.rs` around lines 25 - 40, The
Criterion benchmark ID currently only uses the chunk_bits label, causing
different MODEL runs to collide; update the ID creation in the bench loop where
BenchmarkId::new(...) is called (and the associated label variable) to include
the MODEL identifier (e.g., append or incorporate MODEL or a MODEL string
variable into the label/BenchmarkId) so each run is namespaced by both
chunk_bits and MODEL; adjust the label construction (the match on chunk_bits) or
create a new label_with_model and pass that to
group.bench_with_input/BenchmarkId::new to ensure unique report paths per MODEL.
🧹 Nitpick comments (2)
rust/jstprove_circuits/benches/autotuner_strategies.rs (1)

44-67: Make the timed autotuner configuration explicit.

compile_bn254 only invokes autotune_chunk_bits when the passed metadata has logup_chunk_bits == None (rust/jstprove_circuits/src/onnx.rs:655-668). The prewarm block does that, but the measured cold/warm paths reload fresh params and pass them through unchanged, so benchmark semantics depend on whatever generate_from_onnx currently returns. Setting it to None here keeps both benches pinned to the autotuner path.

Proposed fix
-    let params = metadata.circuit_params.clone();
+    let mut params = metadata.circuit_params.clone();
+    params.logup_chunk_bits = None;
@@
-    let params = metadata.circuit_params.clone();
+    let mut params = metadata.circuit_params.clone();
+    params.logup_chunk_bits = None;

Also applies to: 100-123

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/autotuner_strategies.rs` around lines 44 - 67,
The benchmark currently passes the generated params through unchanged so
autotuning is conditional; to force the autotuner path set the params'
logup_chunk_bits to None before using them in the benchmark setup and the
compile calls (i.e., clone metadata.circuit_params into a mutable variable, set
params.logup_chunk_bits = None, then use that params in OnnxContext::set_all and
compile_bn254); do this both in the cold_compile bench setup and the warm/cold
measured blocks referenced around compile_bn254 and where OnnxContext::set_all
is called so autotune_chunk_bits is always invoked.
rust/jstprove_circuits/benches/pipeline_ecc.rs (1)

71-97: These witness/prove numbers still go through the legacy *_from_f64 helpers.

All three backends call the f64 path from rust/jstprove_circuits/src/onnx.rs:382-468, so these benches will not reflect behavior in a prequantized witness entrypoint. If this file is meant to track the production pipeline, switch to that API or rename these as legacy-f64 benchmarks.

Also applies to: 151-177, 232-258

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/pipeline_ecc.rs` around lines 71 - 97, The
benchmarks use the legacy f64 helpers (witness_bn254_from_f64) so they exercise
the old path; replace calls to witness_bn254_from_f64 with the
production/prequantized witness entrypoint (the prequantized witness API used by
the pipeline) and ensure the corresponding prove call (prove_bn254) consumes the
prequantized witness object, or if you intend to keep testing the legacy route,
rename these benches (the "witness" and "prove" bench_function blocks) to
indicate they are legacy-f64; apply the same change to the other similar
benchmark blocks in this file that call witness_bn254_from_f64.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@rust/jstprove_circuits/benches/logup_chunking.rs`:
- Around line 29-60: The adaptive benchmark (chunk_bits == None) is measuring a
warmed autotuner because the autotuner cache written by compile_bn254 (and code
in OnnxContext/onnx.rs) is never cleared; fix by invoking the same
cache-clearing routine used in autotuner_strategies benches before each adaptive
run or split the adaptive case into explicit cold and warm benchmarks: locate
the adaptive branch in the bench loop (where compile_bn254 is called) and either
call the cache-clear helper from
rust/jstprove_circuits/benches/autotuner_strategies.rs (replicate its logic that
removes the autotuner cache) right before calling compile_bn254, or add two
BenchmarkId entries for "adaptive_cold" and "adaptive_warm" and ensure only the
cold variant clears the cache so results are comparable to the fixed-width
Some(...) runs.

In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 247-253: The build method computes spatial_rank from
kernel_shape.len() but the apply method uses fixed 2D indexing; add a guard in
build to ensure spatial_rank == 2 and validate that the vectors strides,
dilation, output_padding, and pads have the expected lengths (strides.len()==2,
dilation.len()==2, output_padding.len()==2, pads.len()==4) before constructing
the layer (e.g., in the code block after kernel_shape and spatial_rank). If any
check fails return an appropriate Err (layer construction error) with a clear
message referencing the layer.name and the offending parameter so non-2D
ConvTranspose layers fail gracefully instead of panicking.
- Around line 125-126: The subtraction for out_h and out_w uses usize and can
underflow; compute the intermediate values in signed space (e.g., cast stride_h,
h_in, eff_kh, out_pad_h, pad_h_begin, pad_h_end and their width counterparts to
i64), perform the arithmetic in i64, check that the resulting out_h_i64 and
out_w_i64 are > 0, then cast to usize only after validation; if either is <= 0
return or propagate an error (or early-return a Result) instead of creating the
ArrayD. Ensure you update the variables computed in the block that sets out_h
and out_w (referencing out_h, out_w, stride_h, h_in, eff_kh, out_pad_h,
pad_h_begin, pad_h_end, stride_w, w_in, eff_kw, out_pad_w, pad_w_begin,
pad_w_end) so no usize subtraction occurs unchecked.
- Around line 107-110: Add explicit rank and shape validation for the tensors
before any multi-dimensional indexing: check that weights.ndim() == 4, compute
kh and kw from self.kernel_shape and verify weights.shape()[0] == c_in,
weights.shape()[1] * self.group as usize == c_out, and weights.shape()[2..4] ==
[kh, kw]; for bias (if Some) assert bias.ndim() == 1 and bias.shape()[0] ==
c_out. Perform these checks at the start of the conv-transpose routine (before
the lines that access weights.shape()[1], bias[[oc]], and weights[[ci, oc, ki_h,
ki_w]]), and return a clear error (or panic) with descriptive messages
referencing the offending tensor and expected shape so malformed external
tensors fail fast and clearly.

In `@rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs`:
- Around line 157-165: The per-element tolerance is set to scaling^2
(per_elem_tol_native = self.scaling * self.scaling), which is too large and
effectively disables the affine output constraint; change the tolerance to be on
the order of scaling (e.g., per_elem_tol_native = self.scaling) so the allowed
error is alpha-scale rather than alpha^2, and recompute per_elem_tolerance and
per_elem_tol_bits accordingly (adjust the usages of per_elem_tol_native,
per_elem_tolerance, and per_elem_tol_bits in the layer_norm implementation to
reflect this smaller tolerance).
- Around line 145-151: The current computation of norm_var_tolerance uses the
same expression as n_alpha_sq which cancels the variance target when computing
var_diff/var_shifted; change norm_var_tol_native (and thus norm_var_tolerance
and norm_var_tol_bits) so the tolerance is strictly smaller than n_alpha_sq
(e.g., use a smaller scalar factor or a fixed minimal tolerance such as 1 *
self.scaling * self.scaling or (lane_size as u64 * self.scaling * self.scaling)
/ 2) so that var_diff = norm_sq_sum - n_alpha_sq is not zeroed out by the
tolerance; update the values for norm_var_tol_native, norm_var_tolerance and
recompute norm_var_tol_bits accordingly and verify behavior around var_diff,
var_shifted, norm_sq_sum and n_alpha_sq.
- Around line 145-151: LayerNormLayer::build must reject scale_exponent >= 32 to
avoid silent u64 overflow when computing self.scaling * self.scaling (used to
produce norm_var_tol_native and related constants); add an explicit guard at the
start of LayerNormLayer::build that checks the layer config's scale_exponent and
returns an Err with a clear message if scale_exponent >= 32 (matching semantics
used by Tanh/Cos layers), instead of relying on the existing checked_shl guard
used only for shifts. Ensure the error is returned from LayerNormLayer::build
before any use of self.scaling, norm_var_tol_native, or similar constant
computations.
- Around line 213-219: The range_check call fails for scale_exponent == 0
because LogupRangeCheckContext::range_check rejects n_bits == 0; modify the
layer_norm logic around norm_rem so that when self.scale_exponent == 0 you do
not call range_check but instead assert that norm_rem is zero (e.g. use
api.assert_is_zero or api.assert_is_equal(norm_rem, zero_var)), otherwise keep
the existing path calling logup_ctx.range_check::<C, Builder>(api, norm_rem,
self.scale_exponent as usize). Update LayerNormLayer::build validation if needed
to keep behavior consistent for scale_exponent == 0.
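
The remainder handling described above can be sketched in plain Rust. `check_norm_rem` and its bounds test are illustrative stand-ins for the circuit builder's `range_check`/`assert_is_zero` calls, not the real API:

```rust
// Hypothetical sketch: when scale_exponent == 0 there is no remainder
// range to check (range_check rejects n_bits == 0), so the remainder
// must be exactly zero; otherwise it must fit in scale_exponent bits.
fn check_norm_rem(norm_rem: u64, scale_exponent: u32) -> Result<(), String> {
    if scale_exponent == 0 {
        if norm_rem != 0 {
            return Err(format!(
                "norm_rem must be 0 when scale_exponent == 0, got {norm_rem}"
            ));
        }
        Ok(())
    } else if norm_rem < (1u64 << scale_exponent) {
        // Stand-in for logup_ctx.range_check(api, norm_rem, scale_exponent).
        Ok(())
    } else {
        Err(format!("norm_rem {norm_rem} exceeds {scale_exponent} bits"))
    }
}
```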

In `@rust/jstprove_onnx/src/quantizer.rs`:
- Around line 570-609: The reshape handling does not validate fixed-shape
targets when there are zero -1 dimensions; inside the block where n_unknown == 0
(i.e., after computing n_unknown from dims and when input_total is known), add
an anyhow::ensure! (or equivalent check) that dims.iter().filter_map(|d|
d).product::<usize>() == input_total and return a clear error using layer.name
if it mismatches so invalid reshapes are rejected before further bound
estimation; update the branch handling n_unknown (the same area that currently
handles n_unknown == 1 and >1, using variables dims, input_total, and
layer.name).
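
The fixed-target check asked for above can be sketched as follows, assuming `dims` holds `None` for `-1` entries as in the surrounding code; the function name and error type are illustrative:

```rust
// Illustrative check for the n_unknown == 0 case: with no -1 dims, the
// fixed reshape target must multiply out to exactly the input element count.
fn validate_fixed_reshape(
    dims: &[Option<usize>],
    input_total: usize,
    layer_name: &str,
) -> Result<(), String> {
    let n_unknown = dims.iter().filter(|d| d.is_none()).count();
    if n_unknown == 0 {
        let target: usize = dims.iter().filter_map(|d| *d).product();
        if target != input_total {
            return Err(format!(
                "Reshape layer '{layer_name}': target size {target} != input size {input_total}"
            ));
        }
    }
    // n_unknown >= 1 is handled by the existing -1 inference branches.
    Ok(())
}
```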

---

Duplicate comments:
In `@rust/jstprove_circuits/benches/logup_chunking.rs`:
- Around line 25-40: The Criterion benchmark ID currently only uses the
chunk_bits label, causing different MODEL runs to collide; update the ID
creation in the bench loop where BenchmarkId::new(...) is called (and the
associated label variable) to include the MODEL identifier (e.g., append or
incorporate MODEL or a MODEL string variable into the label/BenchmarkId) so each
run is namespaced by both chunk_bits and MODEL; adjust the label construction
(the match on chunk_bits) or create a new label_with_model and pass that to
group.bench_with_input/BenchmarkId::new to ensure unique report paths per MODEL.

In `@rust/jstprove_circuits/benches/pipeline_ecc.rs`:
- Around line 42-46: The benchmark group name "pipeline/bn254" and the
bench_function IDs ("compile", "witness", "prove") are static so different MODEL
values are being conflated; include the MODEL identifier in the Criterion IDs by
appending or interpolating the MODEL value into the group name or the per-stage
function name where group is constructed (the call to
c.benchmark_group("pipeline/bn254")) and where group.bench_function("compile" /
"witness" / "prove") is invoked (also at the other similar sites around lines
122-126 and 202-206) so each model produces unique benchmark IDs like
format!("pipeline/{}/compile", MODEL) or format!("compile/{}", MODEL).

---

Nitpick comments:
In `@rust/jstprove_circuits/benches/autotuner_strategies.rs`:
- Around line 44-67: The benchmark currently passes the generated params through
unchanged so autotuning is conditional; to force the autotuner path set the
params' logup_chunk_bits to None before using them in the benchmark setup and
the compile calls (i.e., clone metadata.circuit_params into a mutable variable,
set params.logup_chunk_bits = None, then use that params in OnnxContext::set_all
and compile_bn254); do this both in the cold_compile bench setup and the
warm/cold measured blocks referenced around compile_bn254 and where
OnnxContext::set_all is called so autotune_chunk_bits is always invoked.

In `@rust/jstprove_circuits/benches/pipeline_ecc.rs`:
- Around line 71-97: The benchmarks use the legacy f64 helpers
(witness_bn254_from_f64) so they exercise the old path; replace calls to
witness_bn254_from_f64 with the production/prequantized witness entrypoint (the
prequantized witness API used by the pipeline) and ensure the corresponding
prove call (prove_bn254) consumes the prequantized witness object, or if you
intend to keep testing the legacy route, rename these benches (the "witness" and
"prove" bench_function blocks) to indicate they are legacy-f64; apply the same
change to the other similar benchmark blocks in this file that call
witness_bn254_from_f64.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: b514bf43-772e-4d64-a2f9-f42d11ff5993

📥 Commits

Reviewing files that changed from the base of the PR and between bda679d and c035960.

📒 Files selected for processing (12)
  • compiler/circuit-std-rs/src/logup.rs
  • rust/jstprove_circuits/Cargo.toml
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/layer_bench.rs
  • rust/jstprove_circuits/benches/logup_chunking.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_circuits/src/onnx.rs
  • rust/jstprove_circuits/src/runner/main_runner.rs
  • rust/jstprove_onnx/src/quantizer.rs
✅ Files skipped from review due to trivial changes (2)
  • rust/jstprove_circuits/Cargo.toml
  • rust/jstprove_circuits/benches/layer_bench.rs
🚧 Files skipped from review as they are similar to previous changes (3)
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/src/runner/main_runner.rs
  • rust/jstprove_circuits/src/onnx.rs


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

♻️ Duplicate comments (1)
rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs (1)

179-189: ⚠️ Potential issue | 🟠 Major

Validate bias rank before 1-D indexing

Line 188 uses bias[[oc]] (1-D indexing), but only bias.shape()[0] is checked. A non-1D bias with matching first dimension will panic at runtime.

Proposed fix
-            if bias.shape()[0] != c_out {
+            if bias.ndim() != 1 || bias.shape()[0] != c_out {
                 return Err(LayerError::InvalidShape {
                     layer: LayerKind::ConvTranspose,
-                    msg: format!("bias length {} != output channels {c_out}", bias.shape()[0]),
+                    msg: format!(
+                        "bias must be 1-D with length {c_out}, got shape {:?}",
+                        bias.shape()
+                    ),
                 }
                 .into());
             }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs` around
lines 179 - 189, The code currently only checks bias.shape()[0] but then uses
1-D indexing bias[[oc]], which will panic if bias is not rank-1; update the
validation in the ConvTranspose bias handling (the block that returns
LayerError::InvalidShape and then fills r via r.slice_mut and bias[[oc]]) to
also check bias.ndim() == 1 (or equivalently bias.shape().len() == 1) and return
a LayerError::InvalidShape with a clear message if not 1-D before any indexing
occurs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@rust/jstprove_circuits/benches/autotuner_strategies.rs`:
- Around line 39-41: The code currently returns early when model_path() doesn't
exist, causing benches (bench_cold_compile and bench_warm_compile) to register
no Criterion cases; replace the silent early return with a hard fail that
includes diagnostic info: when path.exists() is false, panic (or assert!) with a
clear message containing the resolved path and mention of the MODEL env var so
the failure is actionable (e.g., "MODEL resolved to {path} - ONNX file not
found"); apply this change in both bench_cold_compile and bench_warm_compile
where model_path() is checked.

In `@rust/jstprove_circuits/benches/pipeline_ecc.rs`:
- Around line 32-33: The benchmark uses metadata.circuit_params from
expander_metadata::generate_from_onnx which can be adaptive (logup_chunk_bits ==
None) causing mixed cold/warm runs; ensure each compile benchmark runs on a
single cache regime by either clearing the autotuner cache at the start of each
measured iteration (call the project’s autotuner cache-clear function before
compiling) or by overriding params.circuit_params.logup_chunk_bits to
Some(fixed_bits) after obtaining metadata so the adaptive path is disabled;
update all places using metadata.circuit_params (the let params =
metadata.circuit_params.clone() sites) in this bench (and the other noted
ranges) to apply one of these fixes.
- Around line 27-30: The benchmark currently silently returns when the selected
model path is missing (let path = model_path(); if !path.exists() { return; }),
which hides a bad MODEL value; change this to fail fast by replacing the early
return with a hard failure (e.g., panic! or std::process::exit with an error
message) that includes the missing path and context so the binary fails loudly;
make the identical change in the corresponding checks in bench_bn254,
bench_goldilocks, and bench_goldilocks_basefold (and the other occurrences noted
around the same file ranges).

In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 112-116: The code around the ConvTranspose weights check is
misformatted; run `cargo fmt` and commit the formatted output to satisfy `cargo
fmt --check`. Specifically, reformat the block containing the `if weights.ndim()
!= 4 { ... }` that constructs `LayerError::InvalidShape` (references:
`weights.ndim()`, `LayerError::InvalidShape`, `LayerKind::ConvTranspose`) and
the nearby similar block later in the file (the other weights-shape validation),
ensuring braces/commas/indentation match rustfmt style and then commit the
changes.
- Around line 332-359: The loop that validates parameter vector lengths
(checking STRIDES, DILATION, "output_padding", PADS) must also guard against
zero stride values to avoid divide/modulo-by-zero in the ConvTranspose apply
logic; update the validation after retrieving strides (the strides: Vec<u32>
from get_param_or_default) to check each element is > 0 and return
LayerError::UnsupportedConfig (LayerKind::ConvTranspose) with a clear message
naming the layer and parameter if any stride equals 0. Keep the existing length
checks and apply the non-zero check specifically for the strides variable (and
similarly for dilation if needed) so downstream code in the apply method won’t
panic.
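
The non-zero stride guard described above amounts to a small validation pass; this sketch is a plain-Rust analogue (function name and error shape are hypothetical, not the crate's `LayerError` type):

```rust
// Reject zero strides up front so later ih_num / stride and
// ih_num % stride computations in apply() cannot divide by zero.
fn validate_strides(strides: &[u32], layer_name: &str) -> Result<(), String> {
    for (i, &s) in strides.iter().enumerate() {
        if s == 0 {
            return Err(format!(
                "ConvTranspose layer '{layer_name}': stride[{i}] must be > 0"
            ));
        }
    }
    Ok(())
}
```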

In `@rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs`:
- Around line 149-155: The computed norm_var_tol_native can be zero causing
norm_var_tol_bits == 0 and a failing range_check; update the logic around
norm_var_tol_native / norm_var_tolerance / norm_var_tol_bits so that if
norm_var_tol_native == 0 you do not call range_check on var_diff but instead
assert var_diff == 0 (mirroring the existing norm_rem handling), otherwise
compute norm_var_tol_bits and call api.range_check(var_diff, norm_var_tol_bits,
norm_var_tolerance) as before; reference the variables norm_var_tol_native,
norm_var_tolerance, norm_var_tol_bits and the value var_diff and modify the
conditional flow where range_check is invoked.
- Around line 142-149: The current computation of n_alpha_sq and
norm_var_tol_native multiplies lane_size as u64 * self.scaling * self.scaling
and can overflow even when scale_exponent < 32 (e.g., scale_exponent=31 with
lane_size≥4). Fix by either (preferred) performing the full product using U256
(compute U256::from(lane_size) * U256::from(self.scaling).pow(2) and use that to
create n_alpha_sq and norm_var_tol_native as U256-based constants) or
(alternate) add a validation where the existing scale_exponent guard is enforced
to check that (lane_size as u128) * (self.scaling as u128) * (self.scaling as
u128) <= u64::MAX and return an error if not; update uses of n_alpha_sq,
norm_var_tol_native, and any code that assumes they fit in u64 accordingly (look
for symbols n_alpha_sq, norm_var_tol_native, self.scaling, scale_exponent,
lane_size in layer_norm.rs).
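
The alternate fix (widen to u128 before comparing against u64::MAX) can be sketched like this; `checked_variance_scale` is a hypothetical helper name:

```rust
// Compute lane_size * scaling^2 in u128 so the product cannot silently
// wrap, then fail if the result does not fit back into u64.
fn checked_variance_scale(lane_size: usize, scaling: u64) -> Result<u64, String> {
    let wide = (lane_size as u128) * (scaling as u128) * (scaling as u128);
    u64::try_from(wide).map_err(|_| {
        format!("lane_size * scaling^2 = {wide} overflows u64; lower scale_exponent")
    })
}
```

For example, `lane_size = 4` with `scaling = 2^31` yields `2^64`, which is one past `u64::MAX` and is rejected, matching the overflow scenario the comment describes.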

---

Duplicate comments:
In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 179-189: The code currently only checks bias.shape()[0] but then
uses 1-D indexing bias[[oc]], which will panic if bias is not rank-1; update the
validation in the ConvTranspose bias handling (the block that returns
LayerError::InvalidShape and then fills r via r.slice_mut and bias[[oc]]) to
also check bias.ndim() == 1 (or equivalently bias.shape().len() == 1) and return
a LayerError::InvalidShape with a clear message if not 1-D before any indexing
occurs.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 14a557b8-1974-4a75-b6b7-d59030873df5

📥 Commits

Reviewing files that changed from the base of the PR and between c035960 and 917ef56.

📒 Files selected for processing (6)
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/logup_chunking.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_onnx/src/quantizer.rs
✅ Files skipped from review due to trivial changes (1)
  • rust/jstprove_circuits/benches/logup_chunking.rs


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
rust/jstprove_onnx/src/quantizer.rs (1)

545-546: ⚠️ Potential issue | 🟠 Major

Reshape still conflates unknown input shape with empty shape and mis-handles -1/0 semantics.

Line 545 fabricates [] for unknown input shape, so Line 546 computes a fake total (1). This can incorrectly infer [-1] -> [1], reject valid empty reshapes (because Line 571 requires input_total > 0), and silently emit 0 at Line 555 when the 0-copy source dimension is unavailable. This is a correctness issue in shape propagation.

Suggested fix
-                        let in_shape: &[usize] = input_shape.as_deref().unwrap_or(&[]);
-                        let input_total: usize = in_shape.iter().product();
+                        let in_shape = input_shape.as_deref();
+                        let input_total = in_shape.map(|s| s.iter().product::<usize>());

@@
-                                    Some(in_shape.get(i).copied().unwrap_or(0))
+                                    Some(
+                                        in_shape
+                                            .and_then(|s| s.get(i).copied())
+                                            .ok_or_else(|| anyhow::anyhow!(
+                                                "Reshape layer '{}': cannot resolve 0 at index {} without a known input shape",
+                                                layer.name, i
+                                            ))?,
+                                    )
@@
-                        if n_unknown == 1 && input_total > 0 {
+                        if n_unknown > 1 {
+                            anyhow::bail!(
+                                "Reshape layer '{}': more than one -1 dimension is not allowed",
+                                layer.name
+                            );
+                        }
+
+                        if let Some(input_total) = input_total {
+                            if n_unknown == 1 {
+                                let known: usize = dims.iter().filter_map(|&d| d).product();
+                                if known == 0 {
+                                    anyhow::bail!(
+                                        "Reshape layer '{}': cannot infer -1 dimension when known product is 0",
+                                        layer.name
+                                    );
+                                }
+                                if input_total % known != 0 {
+                                    anyhow::bail!(
+                                        "Reshape layer '{}': input total {} is not divisible by known dims product {}",
+                                        layer.name, input_total, known
+                                    );
+                                }
+                                let inferred = input_total / known;
+                                for d in &mut dims {
+                                    if d.is_none() {
+                                        *d = Some(inferred);
+                                        break;
+                                    }
+                                }
+                            }
+                        } else if n_unknown == 1 {
+                            anyhow::bail!(
+                                "Reshape layer '{}': cannot infer -1 dimension without a known input shape",
+                                layer.name
+                            );
+                        }
-                        } else if n_unknown > 1 {
-                            anyhow::bail!(
-                                "Reshape layer '{}': more than one -1 dimension is not allowed",
-                                layer.name
-                            );
-                        }

Also applies to: 555-556, 571-615

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_onnx/src/quantizer.rs` around lines 545 - 546, The code
incorrectly treats an unknown input_shape as an empty slice by using
input_shape.as_deref().unwrap_or(&[]), causing input_total to be computed as 1
and breaking -1/0 reshape semantics; instead, preserve the unknown shape
(Option) and compute input_total only when the shape is known (e.g., match on
input_shape.as_deref()), making input_total an Option<usize> so downstream logic
(the reshape handling in the function around the blocks referencing input_total,
the 0-copy dim handling at the region around lines 555-556, and the validation
logic around lines 571-615) can properly: (1) infer a single -1 only when
input_total is Some and remaining dims allow inference, (2) treat a 0-dim as
“copy from corresponding input dim” only if the input shape is Some and that
source dim exists, otherwise error, and (3) avoid rejecting valid
unknown-to-unknown reshape cases — update the branches that currently assume
input_total > 0 to handle None appropriately and emit errors only when a
required concrete input_total or source dim is missing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@rust/jstprove_onnx/src/quantizer.rs`:
- Around line 545-546: The code incorrectly treats an unknown input_shape as an
empty slice by using input_shape.as_deref().unwrap_or(&[]), causing input_total
to be computed as 1 and breaking -1/0 reshape semantics; instead, preserve the
unknown shape (Option) and compute input_total only when the shape is known
(e.g., match on input_shape.as_deref()), making input_total an Option<usize> so
downstream logic (the reshape handling in the function around the blocks
referencing input_total, the 0-copy dim handling at the region around lines
555-556, and the validation logic around lines 571-615) can properly: (1) infer
a single -1 only when input_total is Some and remaining dims allow inference,
(2) treat a 0-dim as “copy from corresponding input dim” only if the input shape
is Some and that source dim exists, otherwise error, and (3) avoid rejecting
valid unknown-to-unknown reshape cases — update the branches that currently
assume input_total > 0 to handle None appropriately and emit errors only when a
required concrete input_total or source dim is missing.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 1d32b06b-0163-4396-b701-0f057498c914

📥 Commits

Reviewing files that changed from the base of the PR and between 917ef56 and 08992a1.

📒 Files selected for processing (3)
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_onnx/src/quantizer.rs
🚧 Files skipped from review as they are similar to previous changes (1)
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
rust/jstprove_circuits/benches/autotuner_strategies.rs (1)

93-116: Consider hoisting metadata parsing to avoid duplicate ONNX parsing.

generate_from_onnx(&path) is called twice—once in the pre-populate block (line 95) and again for the measured iterations (line 109). While this works correctly and keeps the pre-populate block self-contained, you could hoist the parsing above the block to reuse the same metadata.

♻️ Proposed simplification
+    let metadata = expander_metadata::generate_from_onnx(&path).unwrap();
+    let mut params = metadata.circuit_params.clone();
+    params.logup_chunk_bits = None;
+    let arch = metadata.architecture.clone();
+    let wandb = metadata.wandb.clone();
+
     // Pre-populate the cache with one cold compile.
     {
-        let metadata = expander_metadata::generate_from_onnx(&path).unwrap();
-        let mut params = metadata.circuit_params.clone();
-        params.logup_chunk_bits = None;
         clear_cache();
-        OnnxContext::set_all(metadata.architecture, params.clone(), Some(metadata.wandb));
+        OnnxContext::set_all(arch.clone(), params.clone(), Some(wandb.clone()));
         let tmp = tempfile::TempDir::new().unwrap();
         compile_bn254(
             tmp.path().join("circuit.bundle").to_str().unwrap(),
             false,
             Some(params),
+            Some(params.clone()),
         )
         .unwrap();
     }
-
-    let metadata = expander_metadata::generate_from_onnx(&path).unwrap();
-    // Force logup_chunk_bits = None so the measured iterations hit the autotuner
-    // cache path rather than bypassing the autotuner with a fixed chunk width.
-    let mut params = metadata.circuit_params.clone();
-    params.logup_chunk_bits = None;
-    let arch = metadata.architecture.clone();
-    let wandb = metadata.wandb.clone();
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/benches/autotuner_strategies.rs` around lines 93 -
116, Hoist the ONNX parsing by calling
expander_metadata::generate_from_onnx(&path) once and reuse the resulting
metadata for both the pre-populate block and the measured iterations; replace
the two separate generate_from_onnx calls with a single let metadata = ...
before the pre-populate block, remove the second call, and reuse
metadata.circuit_params (cloning only when mutating params for pre-populate and
measurements), keeping references to metadata.architecture and metadata.wandb
for OnnxContext::set_all and later variables (arch, wandb) so compile_bn254 and
the autotuner path use the same parsed metadata.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 305-321: The code currently allows zero-valued kernel dimensions
after checking spatial_rank; update the conv-transpose layer build to reject any
zero in kernel_shape by validating kernel_shape.iter().any(|&d| d == 0) and
returning an error (e.g., LayerError::UnsupportedConfig for
LayerKind::ConvTranspose with a message including layer.name and the invalid
kernel_shape) so layers like [0,k] or [k,0] fail fast during construction.
- Around line 64-70: The code clones the entire input tensor into layer_input
causing heavy memory/time cost; change it to borrow the tensor instead by
removing .clone() so layer_input is a reference: use let layer_input =
input.get(&input_name).ok_or_else(|| LayerError::MissingInput { layer:
LayerKind::ConvTranspose, name: input_name.clone(), })?; and then update any
downstream usage in this function (and signatures if needed) to accept &Tensor
(or &Activation) rather than an owned value so the tensor is not duplicated on
the hot path.
- Around line 219-230: The calculations for ih_num and iw_num should use i64
(not i32) to match earlier i64 dimension math: change casts like "oh as i32" and
"ow as i32" to "oh as i64"/"ow as i64" and keep ih_num/iw_num as i64, perform
modulus and >= checks using i64 (e.g., "% stride_h as i64"), then only convert
the final index to usize after bounds checks (e.g., compute ih = (ih_num as
usize) / (stride_h as usize) or use checked conversion). Update all related
comparisons and modulus operations for stride_h/stride_w, dil_h/dil_w,
pad_h_begin/pad_w_begin to use i64 to avoid silent truncation when computing
ih_num/iw_num in the conv-transpose loop.
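
The i64 index math described above can be sketched for one spatial axis. This is an illustrative standalone function, not the crate's apply loop; the inverse mapping assumed here is ih = (oh + pad_begin - kh * dilation) / stride, valid only when the numerator is non-negative and divisible by the stride:

```rust
// Compute the input row feeding output row `oh` for kernel offset `kh`,
// doing all signed arithmetic in i64 and converting to usize only after
// the divisibility and bounds checks pass.
fn input_row_for_output(
    oh: usize,
    kh: usize,
    stride_h: u32,
    dil_h: u32,
    pad_h_begin: i64,
    h_in: usize,
) -> Option<usize> {
    let ih_num = oh as i64 + pad_h_begin - (kh as i64) * (dil_h as i64);
    if ih_num < 0 || ih_num % (stride_h as i64) != 0 {
        return None; // this (oh, kh) pair receives no contribution
    }
    let ih = (ih_num / (stride_h as i64)) as usize;
    (ih < h_in).then_some(ih)
}
```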

---

Nitpick comments:
In `@rust/jstprove_circuits/benches/autotuner_strategies.rs`:
- Around line 93-116: Hoist the ONNX parsing by calling
expander_metadata::generate_from_onnx(&path) once and reuse the resulting
metadata for both the pre-populate block and the measured iterations; replace
the two separate generate_from_onnx calls with a single let metadata = ...
before the pre-populate block, remove the second call, and reuse
metadata.circuit_params (cloning only when mutating params for pre-populate and
measurements), keeping references to metadata.architecture and metadata.wandb
for OnnxContext::set_all and later variables (arch, wandb) so compile_bn254 and
the autotuner path use the same parsed metadata.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: ca7dd288-ac90-4097-a72f-8c336356f7dc

📥 Commits

Reviewing files that changed from the base of the PR and between 08992a1 and 025fe65.

📒 Files selected for processing (4)
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
✅ Files skipped from review due to trivial changes (1)
  • rust/jstprove_circuits/benches/pipeline_ecc.rs


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs (1)

251-265: Consider using the actual layer name for better diagnostics.

Line 261 hardcodes "ConvTranspose" as the layer_name, but other layers typically use the actual ONNX node name for clearer error messages. The layer name isn't currently stored in the struct, so this would require adding it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs` around
lines 251 - 265, The code currently passes a hardcoded layer_name
"ConvTranspose" into MaybeRescaleParams; instead add a layer_name field to the
ConvTranspose struct (e.g., pub layer_name: String), populate it when the
ConvTranspose instance is constructed from the ONNX node (use the node's actual
name), and then pass self.layer_name.clone() into MaybeRescaleParams in the
maybe_rescale call; update any constructors/creators that build ConvTranspose
(and any trait impls like new_from_node) to accept/store the node name so
diagnostics use the real ONNX node name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 251-265: The code currently passes a hardcoded layer_name
"ConvTranspose" into MaybeRescaleParams; instead add a layer_name field to the
ConvTranspose struct (e.g., pub layer_name: String), populate it when the
ConvTranspose instance is constructed from the ONNX node (use the node's actual
name), and then pass self.layer_name.clone() into MaybeRescaleParams in the
maybe_rescale call; update any constructors/creators that build ConvTranspose
(and any trait impls like new_from_node) to accept/store the node name so
diagnostics use the real ONNX node name.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 64204de7-e573-42ed-bc7d-301d69da7a55

📥 Commits

Reviewing files that changed from the base of the PR and between 025fe65 and 4b4dbfd.

📒 Files selected for processing (1)
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs

millioner and others added 18 commits April 8, 2026 22:24
…rrors, reject unsupported attrs, enforce input non-negativity, fix Pow bound and MatMul K

  factor, validate Split sum
… handle Pad constant fill, MatMul left-constant L1 norm, and ReduceSum noop_with_empty_axes

  semantics
… tanh scale, rename MatMul n_bits, add tanh error tests
…sts, update sqrt docs, and annotate metal_bench raw bytes conversion
…generation

- Full ConvTranspose circuit implementation using inverse-mapping algorithm
  matching ONNX spec (weight layout [C_in, C_out/group, kH, kW])
- ConvTranspose benchmark (compile/witness/prove) following AveragePool pattern
- Fix run_gen_witness double-scaling bug: inputs from input.msgpack are
  pre-quantized i64 values; the old path re-multiplied by alpha via
  witness_from_f64_generic, causing public-input verification to fail.
  New witness_from_prequantized uses apply_input_data (raw i64 → field,
  no alpha re-scaling), making run_gen_verify -i/-o pass correctly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…-bounds diagnostics, and add missing flatten_rmpv_to_f64 helper
… add pow NaN/Inf fallback tests, and harden Reshape dimension validation in

  propagate_shapes
…ive cache clearing), add ConvTranspose shape/underflow guards, tighten LayerNorm

  tolerance constraints, and validate fixed Reshape targets
…ches, add ConvTranspose zero-stride and bias ndim guards, and fix LayerNorm variance

  tolerance overflow and zero-bits crash
@millioner millioner force-pushed the JSTPRV-164_per_layer_benchmarks branch from 4b4dbfd to 93ad1be Compare April 8, 2026 20:25

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
rust/jstprove_circuits/src/runner/main_runner.rs (1)

1049-1054: ⚠️ Potential issue | 🔴 Critical

Keep the file-input contract consistent across witness and debug flows.

Line 1049 now routes InputData straight into witness_from_prequantized, but run_debug_verify_onnx in this same file still flattens the same payload to f64 and re-applies alpha scaling through build_debug_assignment. Those two paths now disagree about whether input.msgpack is raw or pre-quantized, so one of them will generate the wrong witness/trace.

Possible fix

Pick one contract and enforce it at the boundary:

  • keep run_witness_from_inputs on the raw-f64 path, or
  • move the debug/input-file path to genuinely pre-quantized inputs and validate that format before solving.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/src/runner/main_runner.rs` around lines 1049 - 1054,
The two input paths disagree about whether InputData (from input.msgpack) is raw
f64 or pre-quantized: witness_from_prequantized is being fed InputData directly
while run_debug_verify_onnx still flattens to f64 and re-applies alpha via
build_debug_assignment; make the contract consistent by choosing one boundary
and enforcing it—either change run_witness_from_inputs/run_debug_verify_onnx to
pass pre-quantized InputData into witness_from_prequantized (and add validation
of quantized format), or convert the witness_from_prequantized call to accept
raw f64 (rename/use a non-quantized API) so both paths use the same raw-f64
flow; update the code around witness_from_prequantized, run_debug_verify_onnx,
build_debug_assignment, and run_witness_from_inputs to reflect the chosen
contract and add a validation check for input.msgpack format accordingly.
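Whichever contract is chosen, the comment's suggestion to "add a validation check for input.msgpack format" can be sketched as a boundary guard. The function and parameter names below (validate_prequantized, bits) are illustrative, not the crate's actual API; it simply rejects integers that cannot fit the circuit's signed quantization width before any solving happens.

```rust
// Hypothetical boundary check for the "input.msgpack holds pre-quantized
// integers" contract: reject values outside the signed b-bit range before
// handing the payload to the witness solver.
fn validate_prequantized(values: &[i64], bits: u32) -> Result<(), String> {
    // A b-bit signed quantization admits values in [-2^(b-1), 2^(b-1) - 1].
    let max = (1i64 << (bits - 1)) - 1;
    let min = -(1i64 << (bits - 1));
    for (i, &v) in values.iter().enumerate() {
        if v < min || v > max {
            return Err(format!(
                "input[{i}] = {v} is outside the {bits}-bit quantized range [{min}, {max}]"
            ));
        }
    }
    Ok(())
}
```

A check like this makes a mis-routed raw-f64 payload (already scaled by alpha, typically far outside the quantized range) fail loudly at the boundary instead of producing a wrong witness.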
rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs (1)

64-70: ⚠️ Potential issue | 🟠 Major

Borrow the input activation tensor instead of cloning it.

Line 64 clones the entire activation tensor even though this layer only reads from it. On large feature maps that is a full extra allocation/copy per ConvTransposeLayer::apply call and will noticeably skew the new layer benchmarks.

Suggested fix
-        let layer_input = input
-            .get(&input_name.clone())
-            .ok_or_else(|| LayerError::MissingInput {
-                layer: LayerKind::ConvTranspose,
-                name: input_name.clone(),
-            })?
-            .clone();
+        let layer_input = input.get(&input_name).ok_or_else(|| LayerError::MissingInput {
+            layer: LayerKind::ConvTranspose,
+            name: input_name.clone(),
+        })?;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs` around
lines 64 - 70, The code currently clones the activation with .clone() when
retrieving it from the input map (in the block that constructs layer_input),
causing an unnecessary full-copy; instead borrow the tensor: remove .clone() and
bind layer_input as a reference (e.g. let layer_input =
input.get(&input_name.clone()).ok_or_else(|| LayerError::MissingInput { layer:
LayerKind::ConvTranspose, name: input_name.clone(), })?;), then update any
downstream code in ConvTransposeLayer::apply that expects an owned tensor to
accept &Tensor (or call .as_ref() where needed) so no full allocation is
performed for read-only access. Ensure references use the same names input_name
and LayerKind::ConvTranspose so the change is locatable.
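The borrow pattern the comment recommends can be shown in isolation. This is a minimal sketch using a plain HashMap<String, Vec<i64>> as a stand-in for the real activation map; the MissingInput error type here is a placeholder for LayerError::MissingInput.

```rust
use std::collections::HashMap;

// Placeholder for the crate's LayerError::MissingInput variant.
#[derive(Debug)]
struct MissingInput(String);

// Borrow the activation instead of cloning it: `HashMap::get` returns
// Option<&V>, so the happy path allocates nothing regardless of tensor size.
// The name is only cloned on the error path, where an owned String is needed.
fn lookup<'a>(
    input: &'a HashMap<String, Vec<i64>>,
    name: &str,
) -> Result<&'a Vec<i64>, MissingInput> {
    input
        .get(name) // no `.clone()` on the key or the value
        .ok_or_else(|| MissingInput(name.to_string()))
}
```

Downstream code then takes `&Vec<i64>` (or a slice) rather than an owned tensor, which is exactly the change the review asks for in ConvTransposeLayer::apply.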
🧹 Nitpick comments (1)
rust/jstprove_circuits/src/onnx.rs (1)

573-631: Factor the two-pass witness solve into a shared helper.

This block now duplicates most of witness_from_f64_generic, so probe/final-pass fixes and output-handling changes will need to be kept in sync manually. A small private helper that takes a prepared assignment would keep the raw-f64 and pre-quantized paths from drifting.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@rust/jstprove_circuits/src/onnx.rs` around lines 573 - 631, The function
witness_from_prequantized duplicates the two-pass probe/final witness solving
logic found in witness_from_f64_generic; extract that shared logic into a small
private helper (e.g. solve_two_pass_witness) that accepts the prepared
assignment (probe_assignment), the loaded witness_solver (witness_solver),
hint_registry (hint_registry), layered_circuit (layered_circuit) and params
(CircuitParams) and returns the final serialized witness bytes and output_i64
(or a WitnessBundle-like result); then replace the duplicated blocks in
witness_from_prequantized and witness_from_f64_generic to call this helper so
probe pass, output length checks, computed_outputs -> output_i64 conversion,
final solve and serialize_witness are centralized and kept in sync.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@rust/jstprove_circuits/src/circuit_functions/layers/sqrt.rs`:
- Around line 4-7: Move the call to range_check(x) before the new_hint
registration (SQRT_HINT_KEY / RootAPI::new_hint) and add a short comment
clarifying this is a readability/maintenance reorder, not a soundness fix
because new_hint only registers a hint and does not evaluate it; if you want
true input validation for SQRT_HINT_KEY ensure the check is implemented inside
the hint logic itself rather than relying on call order. Reference the functions
range_check and expander_compiler::frontend::RootAPI::new_hint (and the constant
SQRT_HINT_KEY) so reviewers can find and verify the change.

In `@rust/jstprove_circuits/src/runner/main_runner.rs`:
- Around line 1801-1809: The map-branch currently uses filter_map to silently
drop non-string keys, causing partial flattening; instead, update the Value::Map
handling in flatten_rmpv_to_f64 to reject non-string keys by returning an error
when any map key is not a string: iterate entries, check k.as_str() for each
(instead of filter_map), and if it returns None return a descriptive Err
(propagating the error type used by flatten_rmpv_to_f64) that includes that a
non-string map key was encountered and ideally the key's type; continue to sort
and call flatten_rmpv_to_f64(v, out)? only for valid string-keyed pairs.
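The error-on-non-string-key shape the prompt describes can be sketched with a minimal enum standing in for rmpv::Value (the real flatten_rmpv_to_f64 operates on the actual msgpack value type and covers more variants):

```rust
// Minimal stand-in for rmpv::Value, enough to show the key-validation shape.
enum Value {
    F64(f64),
    String(String),
    Map(Vec<(Value, Value)>),
}

fn flatten_to_f64(v: &Value, out: &mut Vec<f64>) -> Result<(), String> {
    match v {
        Value::F64(x) => out.push(*x),
        Value::String(_) => return Err("string leaf is not numeric".into()),
        Value::Map(entries) => {
            // Collect (key, value) pairs, erroring on the first non-string
            // key instead of silently dropping it with filter_map.
            let mut pairs: Vec<(&str, &Value)> = Vec::with_capacity(entries.len());
            for (k, val) in entries {
                match k {
                    Value::String(s) => pairs.push((s.as_str(), val)),
                    _ => return Err("non-string map key in msgpack input".into()),
                }
            }
            // Sort by key for a deterministic flattening order, then recurse.
            pairs.sort_by_key(|(k, _)| *k);
            for (_, val) in pairs {
                flatten_to_f64(val, out)?;
            }
        }
    }
    Ok(())
}
```

With filter_map, a map containing one integer key would flatten to a shorter vector than the circuit expects and fail later with a confusing length mismatch; failing at the offending key keeps the error local.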

---

Duplicate comments:
In `@rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs`:
- Around line 64-70: The code currently clones the activation with .clone() when
retrieving it from the input map (in the block that constructs layer_input),
causing an unnecessary full-copy; instead borrow the tensor: remove .clone() and
bind layer_input as a reference (e.g. let layer_input =
input.get(&input_name.clone()).ok_or_else(|| LayerError::MissingInput { layer:
LayerKind::ConvTranspose, name: input_name.clone(), })?;), then update any
downstream code in ConvTransposeLayer::apply that expects an owned tensor to
accept &Tensor (or call .as_ref() where needed) so no full allocation is
performed for read-only access. Ensure references use the same names input_name
and LayerKind::ConvTranspose so the change is locatable.

In `@rust/jstprove_circuits/src/runner/main_runner.rs`:
- Around line 1049-1054: The two input paths disagree about whether InputData
(from input.msgpack) is raw f64 or pre-quantized: witness_from_prequantized is
being fed InputData directly while run_debug_verify_onnx still flattens to f64
and re-applies alpha via build_debug_assignment; make the contract consistent by
choosing one boundary and enforcing it—either change
run_witness_from_inputs/run_debug_verify_onnx to pass pre-quantized InputData
into witness_from_prequantized (and add validation of quantized format), or
convert the witness_from_prequantized call to accept raw f64 (rename/use a
non-quantized API) so both paths use the same raw-f64 flow; update the code
around witness_from_prequantized, run_debug_verify_onnx, build_debug_assignment,
and run_witness_from_inputs to reflect the chosen contract and add a validation
check for input.msgpack format accordingly.

---

Nitpick comments:
In `@rust/jstprove_circuits/src/onnx.rs`:
- Around line 573-631: The function witness_from_prequantized duplicates the
two-pass probe/final witness solving logic found in witness_from_f64_generic;
extract that shared logic into a small private helper (e.g.
solve_two_pass_witness) that accepts the prepared assignment (probe_assignment),
the loaded witness_solver (witness_solver), hint_registry (hint_registry),
layered_circuit (layered_circuit) and params (CircuitParams) and returns the
final serialized witness bytes and output_i64 (or a WitnessBundle-like result);
then replace the duplicated blocks in witness_from_prequantized and
witness_from_f64_generic to call this helper so probe pass, output length
checks, computed_outputs -> output_i64 conversion, final solve and
serialize_witness are centralized and kept in sync.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 411ffa1a-c8ae-4236-9818-3fe4fd8f7949

📥 Commits

Reviewing files that changed from the base of the PR and between 4b4dbfd and 93ad1be.

📒 Files selected for processing (14)
  • compiler/circuit-std-rs/src/logup.rs
  • rust/jstprove_circuits/Cargo.toml
  • rust/jstprove_circuits/benches/autotuner_strategies.rs
  • rust/jstprove_circuits/benches/layer_bench.rs
  • rust/jstprove_circuits/benches/logup_chunking.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/matmul.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/sqrt.rs
  • rust/jstprove_circuits/src/onnx.rs
  • rust/jstprove_circuits/src/runner/main_runner.rs
  • rust/jstprove_onnx/src/quantizer.rs
✅ Files skipped from review due to trivial changes (3)
  • rust/jstprove_circuits/Cargo.toml
  • rust/jstprove_circuits/src/circuit_functions/layers/matmul.rs
  • rust/jstprove_circuits/benches/layer_bench.rs
🚧 Files skipped from review as they are similar to previous changes (5)
  • rust/jstprove_circuits/src/circuit_functions/hints/pow.rs
  • rust/jstprove_circuits/src/circuit_functions/layers/layer_norm.rs
  • rust/jstprove_circuits/benches/logup_chunking.rs
  • rust/jstprove_onnx/src/quantizer.rs
  • rust/jstprove_circuits/benches/pipeline_ecc.rs
