Introduce Conv+BatchNormalization fusion in constant folding pass#179
HudsonGraeme merged 2 commits into main
Conversation
Walkthrough

Adds a post-constant-fold optimization that fuses Conv+BatchNormalization pairs by merging BN parameters into Conv weights and biases, removing the BatchNormalization nodes and related initializers.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
crates/dsperse/src/slicer/onnx_fold.rs (1)
1077-1078: Consider cleaning up stale BN parameter initializers.

After fusion, the BN parameter initializers (gamma, beta, mean, var) are no longer referenced but remain in graph.initializer. For large models with many BN layers, this could leave significant dead data. Not critical, but a cleanup pass could reduce serialized model size.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@crates/dsperse/src/slicer/onnx_fold.rs` around lines 1077-1078: the BN fusion leaves parameter initializers (gamma/beta/mean/var) in graph.initializer even after their BN nodes were fused; use the existing removed_bn set (the one you insert f.bn_idx into) to purge those stale initializers after the fusion pass by filtering graph.initializer to exclude any initializer whose index/name corresponds to entries in removed_bn; perform the removal in a safe way (collect indices/names to delete first, or use retain/filter rather than mutating while iterating) and update any related maps/lookup structures so no dangling references remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@crates/dsperse/src/slicer/onnx_fold.rs`:
- Around line 1028-1048: The weight-folding code in onnx_fold.rs (the block
using tensor_to_f32 on w_init inside the loop) must guard against tensor_to_f32
returning an empty vector and must update the tensor data type after conversion;
modify the branch that sets w_init.float_data so that you only assign and clear
raw_data if w_data is non-empty (return false / skip if empty), and set the
initializer's data_type to FLOAT (or the appropriate f32 enum) after successful
conversion so the tensor metadata matches the new float_data; ensure these
checks are applied in the same scope where w_init.float_data, raw_data, and dims
are modified.
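The guard described above can be sketched against a simplified tensor type. This is illustrative only: the real code operates on the prost-generated ONNX TensorProto, and the struct, constant, and function names here are stand-ins for the pattern, not the actual onnx_fold.rs definitions.

```rust
// Simplified stand-in for the ONNX TensorProto (illustrative only).
#[derive(Default)]
pub struct Tensor {
    pub data_type: i32, // ONNX dtype enum: 1 = FLOAT
    pub raw_data: Vec<u8>,
    pub float_data: Vec<f32>,
}

pub const FLOAT: i32 = 1;

/// Decode raw_data into f32s; returns an empty Vec for dtypes this
/// sketch does not convert (mirroring the f16 / bf16 case).
pub fn tensor_to_f32(t: &Tensor) -> Vec<f32> {
    if !t.float_data.is_empty() {
        return t.float_data.clone();
    }
    match t.data_type {
        FLOAT => t
            .raw_data
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
        _ => Vec::new(), // unsupported dtype: signal "no data"
    }
}

/// Write fused weights back. Returns false (caller skips the fusion)
/// when the conversion produced no data, so the tensor is never left
/// with a zero-length FLOAT payload.
pub fn write_back(t: &mut Tensor, w_data: Vec<f32>) -> bool {
    if w_data.is_empty() {
        return false;
    }
    t.float_data = w_data;
    t.raw_data.clear();
    t.data_type = FLOAT; // metadata must match the new representation
    true
}
```

The key point is that the empty-vector check and the data_type stamp happen in the same scope as the float_data write, so the tensor can never end up half-converted.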
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 2d3d9898-19b7-41f0-afce-5874002c0fc9
📒 Files selected for processing (1)
crates/dsperse/src/slicer/onnx_fold.rs
Detect Conv -> BatchNormalization pairs where Conv has a single consumer and fuse BN parameters into Conv weights and bias:

W' = W * gamma / sqrt(var + eps)
b' = (b - mean) * gamma / sqrt(var + eps) + beta

If Conv has no bias input, a fused bias initializer is created and the Conv input list extended. The BN node is removed and the Conv output redirected to the BN output name.

Pattern from deep-prove (Lagrange Labs' GKR-based ML proving), where Conv+BN fusion eliminates intermediate materialization and reduces circuit constraint count for BN-heavy models (ResNet, EfficientNet).
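The per-channel arithmetic can be sketched over flat f32 buffers. A simplified illustration, not the PR's implementation (which operates on ONNX initializers); the function signature and layout assumptions are mine:

```rust
/// Fold BN parameters into Conv weights and bias, per output channel c:
///   W'[c] = W[c] * gamma[c] / sqrt(var[c] + eps)
///   b'[c] = (b[c] - mean[c]) * gamma[c] / sqrt(var[c] + eps) + beta[c]
/// `weights` is the Conv kernel flattened as [c_out, c_in * k_h * k_w].
pub fn fuse_conv_bn(
    weights: &mut [f32],
    bias: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    mean: &[f32],
    var: &[f32],
    eps: f32,
) {
    let c_out = bias.len();
    assert_eq!(weights.len() % c_out, 0, "kernel must split evenly per channel");
    let per_channel = weights.len() / c_out;
    for c in 0..c_out {
        // One scalar scale per output channel absorbs the whole BN.
        let scale = gamma[c] / (var[c] + eps).sqrt();
        for w in &mut weights[c * per_channel..(c + 1) * per_channel] {
            *w *= scale;
        }
        bias[c] = (bias[c] - mean[c]) * scale + beta[c];
    }
}
```

Because BN is an affine map per output channel, the fold is exact: the fused Conv produces bit-for-bit the same function as Conv followed by BN (up to floating-point reassociation).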
…ale init GC

Investigate CodeRabbit feedback on PR #179 and apply three fixes at the point where the Conv+BatchNormalization fusion writes back the scaled weights and fused bias:

* Guard the tensor_to_f32 conversion: the helper returns an empty vector for weight dtypes we don't yet convert (f16 / bf16). The previous code would still clear raw_data and assign the empty vector, leaving the Conv with a zero-length FLOAT weight that fails every downstream shape check. Skip the fusion for this Conv instead.
* Stamp w_init.data_type = TensorProto::FLOAT after the float_data write. The initialiser may have arrived as half / bfloat encoded in raw_data; float_data is FLOAT by definition, so the tensor metadata must follow the new representation. Applied to the bias branch as well for consistency.
* Garbage-collect the BN parameter initialisers (gamma, beta, running_mean, running_var) after a successful fusion. Their names are captured up-front into ConvBnFusion::stale_bn_param_names and merged into a stale_init_names set; a post-pass retain on graph.initializer drops any name in the set that is no longer referenced by a surviving node. The still_used guard prevents accidentally deleting an initialiser shared with an unrelated node elsewhere in the graph.

Soundness: the fusion math (w' = gamma * w / sqrt(var + eps), b' = beta + gamma * (b - mean) / sqrt(var + eps)) is unchanged. All three fixes only tighten post-fusion graph hygiene; none admits a previously rejected fusion, so the set of fused pairs is a subset of the previous behaviour.
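The garbage-collection step with the still_used guard can be sketched as follows. The node and graph types here are simplified stand-ins (initializers reduced to names) rather than the real ONNX protobuf structs; stale_init_names follows the naming in the commit message:

```rust
use std::collections::HashSet;

// Simplified graph model for illustration.
pub struct Node {
    pub inputs: Vec<String>,
}
pub struct Graph {
    pub nodes: Vec<Node>,
    pub initializers: Vec<String>, // names only, for brevity
}

/// Drop initializers whose names were captured during fusion, unless a
/// surviving node still references them (the `still_used` guard).
pub fn gc_stale_initializers(graph: &mut Graph, stale_init_names: &HashSet<String>) {
    // Build the reference set first, so we never mutate while iterating.
    let still_used: HashSet<&str> = graph
        .nodes
        .iter()
        .flat_map(|n| n.inputs.iter().map(String::as_str))
        .collect();
    graph
        .initializers
        .retain(|name| !stale_init_names.contains(name) || still_used.contains(name.as_str()));
}
```

Using retain after collecting the reference set sidesteps both pitfalls the review flagged: no mutation during iteration, and no deletion of an initializer shared with an unrelated node.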
Force-pushed 33a0965 to 9ddd750
Summary
Metrics
cargo test --locked
cargo clippy -D warnings
cargo fmt --check

Methodology
Cross-referenced DSperse's constant folding pipeline with deep-prove's Conv+BN fusion pattern and zk-torch's operator coverage. Identified that BatchNormalization nodes following Conv with single consumers can be algebraically absorbed: W' = W * gamma / sqrt(var + eps), b' = (b - mean) * scale + beta.
Implementation uses a two-phase approach: (1) collect all fusible pairs with owned BN parameter data, (2) apply mutations to graph initializers and nodes. This avoids simultaneous borrow conflicts between init_map and graph mutation.
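The borrow-checker-friendly shape of that two-phase pattern, in miniature. The types are illustrative (not the real onnx_fold.rs definitions), and the transformation is a trivial scaling rather than the actual BN fold, to keep the focus on the borrow structure:

```rust
// Illustrative miniature of the collect-then-mutate pattern.
pub struct Init {
    pub name: String,
    pub data: Vec<f32>,
}
pub struct Graph {
    pub initializers: Vec<Init>,
}

// Owned record of one planned mutation, built in phase 1.
struct Pending {
    idx: usize,
    fused: Vec<f32>,
}

/// Phase 1 scans the graph immutably and clones everything phase 2
/// needs into owned structs; phase 2 then mutates freely, because no
/// immutable borrow of `graph.initializers` is still alive.
pub fn scale_initializers(graph: &mut Graph, factor: f32) {
    let pending: Vec<Pending> = graph
        .initializers
        .iter()
        .enumerate()
        .map(|(idx, init)| Pending {
            idx,
            fused: init.data.iter().map(|v| v * factor).collect(),
        })
        .collect(); // the immutable borrow ends here
    for p in pending {
        graph.initializers[p.idx].data = p.fused;
    }
}
```

Holding a HashMap of &TensorProto across the mutation (the rejected first design) would keep an immutable borrow alive into phase 2, which the borrow checker refuses; paying for the clone in phase 1 dissolves the conflict.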
Significance Verdict
significant: eliminates BN nodes entirely for fusible pairs, reducing slice count and circuit constraints for BN-heavy models (ResNet, EfficientNet, MobileNet). Expected 10-15% proof time reduction on models with Conv+BN chains.

SR&ED Description
Algorithm Development for operator fusion; Architecture in model preprocessing; Engineering
What changed
onnx_fold.rs: added a ConvBnFusion struct for collecting fusion data and a fuse_conv_batchnorm() function that detects Conv -> BN pairs, computes fused weights/bias, and removes BN nodes. Called from fold_constant_nodes() after constant propagation.

What did NOT work
The first implementation used a HashMap<&str, &TensorProto> (init_map) that borrowed graph.initializer immutably while needing mutable access for weight patching. Restructured to collect all BN parameters into owned ConvBnFusion structs in phase 1, then apply mutations in phase 2.

Benchmark reproduction
cargo build --release --manifest-path crates/dsperse/Cargo.toml
cargo test --locked
cargo clippy --workspace --all-targets -- -D warnings
cargo fmt --check

Files changed
crates/dsperse/src/slicer/onnx_fold.rs: added Conv+BN fusion pass (+207 lines)