Skip to content

Commit 942c287

Browse files
committed
Fix AttnRes block semantics and two-phase inference
1 parent 162c186 commit 942c287

12 files changed

Lines changed: 356 additions & 257 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ h = h_l + F(h_l) V = [b₀; b₁; …; bₙ] ← stack all blocks
3232
h = Σ αᵢ · Vᵢ ← weighted combination
3333
```
3434

35-
Each transformer layer has **two** AttnRes operations (before self-attention and before MLP), each with its own learned pseudo-query vector **w_l** initialized to zero. At initialization, all blocks receive equal weight (standard residual behavior). During training, the model learns to selectively route information from the most relevant depths.
35+
Each transformer layer has **two** AttnRes operations (before self-attention and before MLP), each with its own learned pseudo-query vector **w_l** initialized to zero. At initialization, all available sources receive equal weight. During training, the model learns to selectively route information from the most relevant depths.
3636

3737
## Quick Start
3838

examples/compare_residuals.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Compare standard residual connections vs AttnRes.
22
//!
33
//! Shows that an AttnRes model with zero-initialized pseudo-queries
4-
//! starts as equivalent to uniform averaging over all prior blocks,
4+
//! starts as uniform averaging over all prior blocks,
55
//! and demonstrates the forward pass works correctly.
66
//!
77
//! Run with: `cargo run --example compare_residuals`
@@ -20,7 +20,7 @@ fn main() {
2020

2121
// Demo 1: Zero-init produces uniform weights (equivalent to mean)
2222
println!("1. Zero-initialized AttnRes = uniform averaging");
23-
println!(" (equivalent to standard residual connections)\n");
23+
println!(" (equal weights over all available sources)\n");
2424

2525
let config = AttnResConfig::new(32, 4, 2);
2626
let op: AttnResOp<B> = config.init_op(&device);

examples/demo_tui.rs

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,12 @@ fn train_step(
204204
fn compute_alpha(
205205
op: &AttnResOp<AB>,
206206
blocks: &[Tensor<AB, 3>],
207-
partial: &Tensor<AB, 3>,
207+
partial: Option<&Tensor<AB, 3>>,
208208
) -> Vec<f32> {
209209
let mut sources: Vec<Tensor<AB, 3>> = blocks.to_vec();
210-
sources.push(partial.clone());
210+
if let Some(partial) = partial {
211+
sources.push(partial.clone());
212+
}
211213
let n = sources.len();
212214

213215
let v = Tensor::stack(sources, 0); // [N+1, B, T, D]
@@ -268,26 +270,8 @@ fn extract_diagnostics(
268270
norms.push(norm_a);
269271
norms.push(norm_m);
270272

271-
// Replicate boundary handling from layer.forward() to get
272-
// the correct block state for alpha computation.
273-
let current_partial = state
274-
.partial_block
275-
.clone()
276-
.unwrap_or_else(|| Tensor::zeros_like(state.blocks.last().unwrap()));
277-
278-
let at_boundary = layer.is_at_boundary();
279-
let mut blocks_snap = state.blocks.clone();
280-
if at_boundary {
281-
blocks_snap.push(current_partial.clone());
282-
}
283-
let partial_snap = if at_boundary {
284-
Tensor::zeros_like(blocks_snap.last().unwrap())
285-
} else {
286-
current_partial
287-
};
288-
289273
// Compute actual attention weights for the attn sublayer
290-
let alpha = compute_alpha(attn_res, &blocks_snap, &partial_snap);
274+
let alpha = compute_alpha(attn_res, &state.blocks, state.partial_block.as_ref());
291275
depth_weights.push(alpha);
292276

293277
// Run the real forward to advance block state

examples/visualize_weights.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ fn main() {
141141
println!(" - This allows selective information routing across depth");
142142
println!();
143143
println!(" At initialization (zero pseudo-queries), all layers attend");
144-
println!(" uniformly, equivalent to standard residual connections.");
144+
println!(" uniformly across all available sources.");
145145
println!(" Training gradually differentiates the attention patterns.");
146146

147147
println!("\nDone!");

src/attn_res_op.rs

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ impl AttnResConfig {
3434
///
3535
/// The pseudo-query is zero-initialized per the paper's requirement for
3636
/// training stability. This means the operation starts as uniform averaging
37-
/// over all sources (equivalent to standard residual connections).
37+
/// over all available sources.
3838
pub fn init_op<B: Backend>(&self, device: &B::Device) -> AttnResOp<B> {
3939
AttnResOp {
4040
// CRITICAL: zero initialization per paper requirement
@@ -47,47 +47,62 @@ impl AttnResConfig {
4747
}
4848

4949
impl<B: Backend> AttnResOp<B> {
50-
/// Compute attention residual over block representations.
51-
///
52-
/// # Arguments
53-
/// * `blocks` - Completed block representations [N tensors of shape [B, T, D]]
54-
/// * `partial_block` - Current intra-block partial sum [B, T, D]
50+
/// Compute attention residual over any available block representations.
5551
///
56-
/// # Returns
57-
/// * Attention-weighted combination of all sources [B, T, D]
58-
pub fn forward(&self, blocks: &[Tensor<B, 3>], partial_block: &Tensor<B, 3>) -> Tensor<B, 3> {
59-
// Step 1: Stack all sources into value matrix
60-
// V: [N+1, B, T, D]
52+
/// `partial_block` is optional because the first sublayer of the network
53+
/// and the first sublayer of each new block attend only over completed
54+
/// blocks (Eq. 6 in the paper) and therefore have no intra-block partial.
55+
pub fn forward_optional_partial(
56+
&self,
57+
blocks: &[Tensor<B, 3>],
58+
partial_block: Option<&Tensor<B, 3>>,
59+
) -> Tensor<B, 3> {
6160
let mut sources: Vec<Tensor<B, 3>> = blocks.to_vec();
62-
sources.push(partial_block.clone());
63-
let v = Tensor::stack(sources, 0); // [N+1, B, T, D]
61+
if let Some(partial_block) = partial_block {
62+
sources.push(partial_block.clone());
63+
}
64+
65+
assert!(
66+
!sources.is_empty(),
67+
"AttnResOp requires at least one source tensor"
68+
);
69+
70+
// Step 1: Stack all sources into value matrix
71+
// V: [N, B, T, D] or [N+1, B, T, D]
72+
let v = Tensor::stack(sources, 0);
6473

6574
// Step 2: Apply RMSNorm to get keys
66-
// K: [N+1, B, T, D]
75+
// K: same shape as V
6776
let k = self.norm.forward_4d(v.clone());
6877

6978
// Step 3: Compute attention logits
70-
// w: [D] -> [1, 1, 1, D] for broadcasting
71-
// logits = sum(K * w, dim=3) -> [N+1, B, T]
7279
let w = self
7380
.pseudo_query
7481
.val()
7582
.unsqueeze_dim::<2>(0)
7683
.unsqueeze_dim::<3>(0)
7784
.unsqueeze_dim::<4>(0); // [1, 1, 1, D]
78-
let logits = (k * w).sum_dim(3).squeeze_dim::<3>(3); // [N+1, B, T]
85+
let logits = (k * w).sum_dim(3).squeeze_dim::<3>(3);
7986

8087
// Step 4: Softmax over the depth dimension (dim=0)
81-
// CRITICAL: softmax over depth, NOT sequence
82-
let alpha = softmax(logits, 0); // [N+1, B, T]
88+
let alpha = softmax(logits, 0);
8389

8490
// Step 5: Weighted sum of values
85-
// alpha: [N+1, B, T] -> [N+1, B, T, 1]
86-
// v: [N+1, B, T, D]
87-
// result: sum over dim=0 -> [B, T, D]
88-
let alpha_expanded = alpha.unsqueeze_dim::<4>(3); // [N+1, B, T, 1]
89-
let weighted = v * alpha_expanded; // [N+1, B, T, D]
90-
weighted.sum_dim(0).squeeze_dim::<3>(0) // [B, T, D]
91+
let alpha_expanded = alpha.unsqueeze_dim::<4>(3);
92+
let weighted = v * alpha_expanded;
93+
weighted.sum_dim(0).squeeze_dim::<3>(0)
94+
}
95+
96+
/// Compute attention residual over block representations.
97+
///
98+
/// # Arguments
99+
/// * `blocks` - Completed block representations [N tensors of shape [B, T, D]]
100+
/// * `partial_block` - Current intra-block partial sum [B, T, D]
101+
///
102+
/// # Returns
103+
/// * Attention-weighted combination of all sources [B, T, D]
104+
pub fn forward(&self, blocks: &[Tensor<B, 3>], partial_block: &Tensor<B, 3>) -> Tensor<B, 3> {
105+
self.forward_optional_partial(blocks, Some(partial_block))
91106
}
92107
}
93108

@@ -160,4 +175,20 @@ mod tests {
160175
let diff: f32 = (output - expected).abs().max().into_scalar();
161176
assert!(diff < 1e-4, "Single block should produce mean, diff={diff}");
162177
}
178+
179+
#[test]
180+
fn test_blocks_only_returns_only_source() {
181+
let device = Default::default();
182+
let config = AttnResConfig::new(32, 4, 2);
183+
let op = config.init_op::<TestBackend>(&device);
184+
185+
let embedding = Tensor::random([1, 8, 32], Distribution::Normal(0.0, 1.0), &device);
186+
let output = op.forward_optional_partial(&[embedding.clone()], None);
187+
188+
let diff: f32 = (output - embedding).abs().max().into_scalar();
189+
assert!(
190+
diff < 1e-5,
191+
"A single completed block should be returned unchanged, diff={diff}"
192+
);
193+
}
163194
}

src/layer.rs

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ impl AttnResConfig {
6464
}
6565

6666
impl<B: Backend> AttnResLayer<B> {
67+
fn attn_sublayer_idx(&self) -> usize {
68+
self.layer_idx * 2
69+
}
70+
71+
fn mlp_sublayer_idx(&self) -> usize {
72+
self.attn_sublayer_idx() + 1
73+
}
74+
75+
fn starts_new_block_before_sublayer(&self, sublayer_idx: usize) -> bool {
76+
sublayer_idx > 0 && sublayer_idx.is_multiple_of(self.block_size)
77+
}
78+
79+
pub(crate) fn starts_new_block_before_attn(&self) -> bool {
80+
self.starts_new_block_before_sublayer(self.attn_sublayer_idx())
81+
}
82+
83+
pub(crate) fn starts_new_block_before_mlp(&self) -> bool {
84+
self.starts_new_block_before_sublayer(self.mlp_sublayer_idx())
85+
}
86+
6787
/// Get the layer index.
6888
pub fn layer_idx(&self) -> usize {
6989
self.layer_idx
@@ -74,10 +94,14 @@ impl<B: Backend> AttnResLayer<B> {
7494
self.block_size
7595
}
7696

77-
/// Check if this layer is at a block boundary.
97+
/// Check if this layer's attention sublayer starts a new block.
98+
///
99+
/// Block sizing is defined in sublayers, so the MLP sublayer can also
100+
/// start a new block when `block_size` is odd or when `block_size == 1`
101+
/// (Full AttnRes). This helper preserves the historical public API by
102+
/// reporting only the pre-attention boundary.
78103
pub fn is_at_boundary(&self) -> bool {
79-
let half_block = self.block_size / 2;
80-
self.layer_idx > 0 && (half_block == 0 || self.layer_idx.is_multiple_of(half_block))
104+
self.starts_new_block_before_attn()
81105
}
82106

83107
/// Get references to the AttnRes operations (attn_res, mlp_res).
@@ -116,33 +140,18 @@ impl<B: Backend> AttnResLayer<B> {
116140
/// # Returns
117141
/// * Updated block state
118142
pub fn forward(&self, mut state: BlockState<B>, mask: Option<&Tensor<B, 3>>) -> BlockState<B> {
119-
// Get the current partial block, or zeros if at the start of a new block
120-
let current_partial = state
121-
.partial_block
122-
.take()
123-
.unwrap_or_else(|| Tensor::zeros_like(state.blocks.last().unwrap()));
124-
125-
// === Check block boundary ===
126-
// Block boundary occurs every block_size/2 transformer layers (each layer = 2 sublayers).
127-
// For Full AttnRes (block_size=1), every layer after the first is a boundary.
128-
let half_block = self.block_size / 2;
129-
let at_boundary =
130-
self.layer_idx > 0 && (half_block == 0 || self.layer_idx.is_multiple_of(half_block));
131-
132-
if at_boundary {
133-
// Push the completed partial block as a new block
134-
state.blocks.push(current_partial.clone());
135-
}
136-
137-
// The partial block for AttnRes input: if we just pushed, start fresh; otherwise use current
138-
let partial_for_attn = if at_boundary {
139-
Tensor::zeros_like(state.blocks.last().unwrap())
140-
} else {
141-
current_partial
142-
};
143-
144143
// === AttnRes before self-attention ===
145-
let h = self.attn_res.forward(&state.blocks, &partial_for_attn);
144+
let current_partial = state.partial_block.take();
145+
let h = self
146+
.attn_res
147+
.forward_optional_partial(&state.blocks, current_partial.as_ref());
148+
149+
let mut partial_for_attn =
150+
current_partial.unwrap_or_else(|| Tensor::zeros_like(state.blocks.last().unwrap()));
151+
if self.starts_new_block_before_attn() {
152+
state.blocks.push(partial_for_attn.clone());
153+
partial_for_attn = Tensor::zeros_like(state.blocks.last().unwrap());
154+
}
146155

147156
// === Self-attention sublayer ===
148157
let normed = self.attn_norm.forward(h);
@@ -152,14 +161,22 @@ impl<B: Backend> AttnResLayer<B> {
152161
let partial_after_attn = partial_for_attn + attn_out;
153162

154163
// === AttnRes before MLP ===
155-
let h = self.mlp_res.forward(&state.blocks, &partial_after_attn);
164+
let h = self
165+
.mlp_res
166+
.forward_optional_partial(&state.blocks, Some(&partial_after_attn));
167+
168+
let mut partial_for_mlp = partial_after_attn;
169+
if self.starts_new_block_before_mlp() {
170+
state.blocks.push(partial_for_mlp.clone());
171+
partial_for_mlp = Tensor::zeros_like(state.blocks.last().unwrap());
172+
}
156173

157174
// === MLP sublayer ===
158175
let normed = self.mlp_norm.forward(h);
159176
let mlp_out = self.mlp.forward(normed);
160177

161178
// Update partial block with MLP output
162-
let partial_after_mlp = partial_after_attn + mlp_out;
179+
let partial_after_mlp = partial_for_mlp + mlp_out;
163180

164181
state.partial_block = Some(partial_after_mlp);
165182
state

0 commit comments

Comments
 (0)