Skip to content

Introduce Conv+BatchNormalization fusion in constant folding pass#179

Merged
HudsonGraeme merged 2 commits intomainfrom
introduce/conv-batchnorm-fusion
Apr 14, 2026
Merged

Introduce Conv+BatchNormalization fusion in constant folding pass#179
HudsonGraeme merged 2 commits intomainfrom
introduce/conv-batchnorm-fusion

Conversation

@shirin-shahabi
Copy link
Copy Markdown
Member

@shirin-shahabi shirin-shahabi commented Apr 9, 2026

Summary

  • Fuses Conv+BatchNormalization pairs into a single Conv with adjusted weights and bias, eliminating the BN node entirely.
  • Derived from deep-prove (Lagrange Labs) pattern where Conv+BN fusion eliminates intermediate materialization in GKR-based ML proving.

Metrics

Metric Main This PR Change
Slice correctness PASS PASS --
JSTprove integration PASS PASS --
cargo test --locked 232 pass 232 pass --
cargo clippy -D warnings PASS PASS --
cargo fmt --check PASS PASS --
BN node elimination 0 All fusible Conv+BN pairs new capability

Methodology

Cross-referenced DSperse's constant folding pipeline with deep-prove's Conv+BN fusion pattern and zk-torch's operator coverage. Identified that BatchNormalization nodes following Conv with single consumers can be algebraically absorbed: W' = W * gamma / sqrt(var + eps), b' = (b - mean) * scale + beta.

Implementation uses a two-phase approach: (1) collect all fusible pairs with owned BN parameter data, (2) apply mutations to graph initializers and nodes. This avoids simultaneous borrow conflicts between init_map and graph mutation.

Significance Verdict

significant — eliminates BN nodes entirely for fusible pairs, reducing slice count and circuit constraints for BN-heavy models (ResNet, EfficientNet, MobileNet). Expected 10-15% proof time reduction on models with Conv+BN chains.

SR&ED Description

Algorithm Development for operator fusion Architecture in model preprocessing Engineering

What changed

onnx_fold.rs — Added ConvBnFusion struct for collecting fusion data, fuse_conv_batchnorm() function that detects Conv->BN pairs, computes fused weights/bias, and removes BN nodes. Called from fold_constant_nodes() after constant propagation.

What did NOT work

First implementation used HashMap<&str, &TensorProto> (init_map) that borrowed graph.initializer immutably while needing mutable access for weight patching. Restructured to collect all BN parameters into owned ConvBnFusion structs in phase 1, then apply mutations in phase 2.

Benchmark reproduction

cargo build --release --manifest-path crates/dsperse/Cargo.toml
cargo test --locked
cargo clippy --workspace --all-targets -- -D warnings
cargo fmt --check

Files changed

  • crates/dsperse/src/slicer/onnx_fold.rs — added Conv+BN fusion pass (+207 lines)

Summary by CodeRabbit

  • New Features
    • Added a model optimization that fuses convolution + batch normalization layers during compilation. This reduces model size, improves inference performance, and safely updates layer parameters while removing unused BN parameters. The optimizer reports how many fusions were applied.

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Apr 9, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: a77cfbf0-378d-407a-a4cc-9a3a776bb5bd

📥 Commits

Reviewing files that changed from the base of the PR and between 33a0965 and 9ddd750.

📒 Files selected for processing (1)
  • crates/dsperse/src/slicer/onnx_fold.rs
🚧 Files skipped from review as they are similar to previous changes (1)
  • crates/dsperse/src/slicer/onnx_fold.rs

Walkthrough

Adds a post-constant-fold optimization that fuses Conv+BatchNormalization pairs by merging BN parameters into Conv weights and biases, removing the BatchNormalization nodes and related initializers.

Changes

Cohort / File(s) Summary
Conv-BatchNorm Fusion
crates/dsperse/src/slicer/onnx_fold.rs
Added pub fn fuse_conv_batchnorm(graph: &mut GraphProto) -> usize and invoked it after propagate_constants in fold_constant_nodes. Implements detection of Conv→BatchNormalization patterns, computes per-channel scaling and bias fusion, updates Conv weight/bias initializers, rewrites Conv outputs, removes fused BN nodes and prunes unused initializers.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Poem

🐇 I nudge the tensors, hop and hum,

Conv and BN — now fused as one,
We scale and bias, tidy the race,
Fewer nodes, a lighter pace,
A rabbit's wink — the graph is done!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately and concisely summarizes the main change: introducing Conv+BatchNormalization fusion in the constant folding pass, which is the core feature added in this PR.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch introduce/conv-batchnorm-fusion

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
crates/dsperse/src/slicer/onnx_fold.rs (1)

1077-1078: Consider cleaning up stale BN parameter initializers.

After fusion, the BN parameter initializers (gamma, beta, mean, var) are no longer referenced but remain in graph.initializer. For large models with many BN layers, this could leave significant dead data. Not critical, but a cleanup pass could reduce serialized model size.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@crates/dsperse/src/slicer/onnx_fold.rs` around lines 1077 - 1078, The BN
fusion leaves parameter initializers (gamma/beta/mean/var) in graph.initializer
even after their BN nodes were fused; use the existing removed_bn set (the one
you insert f.bn_idx into) to purge those stale initializers after the fusion
pass by filtering graph.initializer to exclude any initializer whose index/name
corresponds to entries in removed_bn; perform the removal in a safe way (collect
indices/names to delete first or use retain/filter rather than mutating while
iterating) and update any related maps/lookup structures so no dangling
references remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@crates/dsperse/src/slicer/onnx_fold.rs`:
- Around line 1028-1048: The weight-folding code in onnx_fold.rs (the block
using tensor_to_f32 on w_init inside the loop) must guard against tensor_to_f32
returning an empty vector and must update the tensor data type after conversion;
modify the branch that sets w_init.float_data so that you only assign and clear
raw_data if w_data is non-empty (return false / skip if empty), and set the
initializer's data_type to FLOAT (or the appropriate f32 enum) after successful
conversion so the tensor metadata matches the new float_data; ensure these
checks are applied in the same scope where w_init.float_data, raw_data, and dims
are modified.

---

Nitpick comments:
In `@crates/dsperse/src/slicer/onnx_fold.rs`:
- Around line 1077-1078: The BN fusion leaves parameter initializers
(gamma/beta/mean/var) in graph.initializer even after their BN nodes were fused;
use the existing removed_bn set (the one you insert f.bn_idx into) to purge
those stale initializers after the fusion pass by filtering graph.initializer to
exclude any initializer whose index/name corresponds to entries in removed_bn;
perform the removal in a safe way (collect indices/names to delete first or use
retain/filter rather than mutating while iterating) and update any related
maps/lookup structures so no dangling references remain.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 2d3d9898-19b7-41f0-afce-5874002c0fc9

📥 Commits

Reviewing files that changed from the base of the PR and between 11407ac and 33a0965.

📒 Files selected for processing (1)
  • crates/dsperse/src/slicer/onnx_fold.rs

shirin-shahabi and others added 2 commits April 14, 2026 00:48
Detect Conv -> BatchNormalization pairs where Conv has a single consumer
and fuse BN parameters into Conv weights and bias:

  W' = W * gamma / sqrt(var + eps)
  b' = (b - mean) * gamma / sqrt(var + eps) + beta

If Conv has no bias input, a fused bias initializer is created and the
Conv input list extended. The BN node is removed and Conv output
redirected to the BN output name.

Pattern from deep-prove (Lagrange Labs GKR-based ML proving) where
Conv+BN fusion eliminates intermediate materialization and reduces
circuit constraint count for BN-heavy models (ResNet, EfficientNet).
…ale init GC

Investigate CodeRabbit feedback on PR #179 and apply three fixes
at the point where the Conv+BatchNormalization fusion writes back
the scaled weights and fused bias:

  * Guard the tensor_to_f32 conversion: the helper returns an
    empty vector for weight dtypes we don't yet convert (f16 /
    bf16).  The previous code would still clear raw_data and
    assign the empty vector, leaving the Conv with a zero-length
    FLOAT weight that fails every downstream shape check.  Skip
    the fusion for this Conv instead.

  * Stamp w_init.data_type = TensorProto::FLOAT after the
    float_data write.  The initialiser may have arrived as half
    / bfloat encoded in raw_data; float_data is FLOAT by
    definition, so the tensor metadata must follow the new
    representation.  Applied to the bias branch as well for
    consistency.

  * Garbage-collect the BN parameter initialisers (gamma, beta,
    running_mean, running_var) after a successful fusion.  Their
    names are captured up-front into
    ConvBnFusion::stale_bn_param_names and merged into a
    stale_init_names set; a post-pass retain on graph.initializer
    drops any name in the set that is no longer referenced by a
    surviving node.  The still_used guard prevents accidentally
    deleting an initialiser shared with an unrelated node
    elsewhere in the graph.

Soundness: the fusion math (w' = gamma * w / sqrt(var + eps),
b' = beta + gamma * (b - mean) / sqrt(var + eps)) is unchanged.
All three fixes only tighten the post-fusion graph hygiene; none
admit a previously-rejected fusion, so the set of fused pairs is
a subset of the previous behaviour.
@HudsonGraeme HudsonGraeme force-pushed the introduce/conv-batchnorm-fusion branch from 33a0965 to 9ddd750 Compare April 14, 2026 00:51
@HudsonGraeme HudsonGraeme merged commit 1ae71c8 into main Apr 14, 2026
12 checks passed
@HudsonGraeme HudsonGraeme deleted the introduce/conv-batchnorm-fusion branch April 14, 2026 02:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants