3. PyTorch Implementation:
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class DRRMemory:
short_term: List[torch.Tensor] # list of [B, T, D]
long_term: List[torch.Tensor] # list of [B, S, D]
class DepthMemoryAttention(nn.Module):
"""
Lightweight multi-head attention over an external memory bank.
Query comes from current hidden states x: [B, T, D]
Memory comes from past states/summaries: [B, M, D]
Output: [B, T, D]
"""
def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
self.dropout = dropout
def _shape(self, x: torch.Tensor) -> torch.Tensor:
# [B, N, D] -> [B, H, N, Hd]
B, N, D = x.shape
x = x.view(B, N, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def forward(
self,
x: torch.Tensor, # [B, T, D]
memory: torch.Tensor, # [B, M, D]
) -> torch.Tensor:
if memory.numel() == 0:
return torch.zeros_like(x)
q = self._shape(self.q_proj(x))
k = self._shape(self.k_proj(memory))
v = self._shape(self.v_proj(memory))
attn = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
) # [B, H, T, Hd]
attn = attn.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.dim)
return self.out_proj(attn)
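# Usage sketch (shapes chosen for illustration): one memory read over M cached states.
#   mem_attn = DepthMemoryAttention(dim=512, num_heads=8)
#   ctx = mem_attn(torch.randn(2, 16, 512), torch.randn(2, 48, 512))  # -> [2, 16, 512]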
class SummaryPool(nn.Module):
"""
Converts a block output [B, T, D] into S summary tokens [B, S, D].
This compresses sequence length, not feature dimension.
Much safer than ad-hoc feature compression for residual routing.
"""
def __init__(self, dim: int, summary_tokens: int):
super().__init__()
self.summary_tokens = summary_tokens
self.norm = nn.LayerNorm(dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, T, D]
x = self.proj(self.norm(x))
# Adaptive average pool over sequence length: [B, T, D] -> [B, S, D]
x_t = x.transpose(1, 2) # [B, D, T]
pooled = F.adaptive_avg_pool1d(x_t, self.summary_tokens) # [B, D, S]
return pooled.transpose(1, 2).contiguous()
class DRRLayer(nn.Module):
"""
Improved Depth Routing Residuals layer:
1) identity path
2) short-term attention over recent full token states
3) long-term attention over block summary tokens
Routing is token-wise: [B, T, 3], not sequence-global.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
short_window: int = 4,
summary_tokens: int = 8,
dropout: float = 0.0,
router_hidden_mult: int = 1,
):
super().__init__()
self.dim = dim
self.short_window = short_window
self.summary_tokens = summary_tokens
# Token-wise router
hidden = dim * router_hidden_mult
self.router = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden),
nn.GELU(),
nn.Linear(hidden, 3),
)
# Bias toward identity at init for stability
nn.init.zeros_(self.router[-1].weight)
nn.init.zeros_(self.router[-1].bias)
with torch.no_grad():
self.router[-1].bias[0] = 2.0 # identity
self.router[-1].bias[1] = 0.0 # short
self.router[-1].bias[2] = -0.5 # long
self.short_mem_attn = DepthMemoryAttention(dim, num_heads, dropout)
self.long_mem_attn = DepthMemoryAttention(dim, num_heads, dropout)
# Standard PreNorm transformer block
self.norm1 = nn.LayerNorm(dim)
self.self_attn = nn.MultiheadAttention(
dim, num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim),
nn.Dropout(dropout),
)
def _build_short_memory(self, short_term: List[torch.Tensor]) -> torch.Tensor:
"""
Concatenate recent layer outputs along the memory axis.
Input: list of [B, T, D]
        Output: [B, M, D] where M = min(len(short_term), short_window) * T
"""
if not short_term:
return torch.empty(0, device=self.router[-1].bias.device)
recent = short_term[-self.short_window:]
return torch.cat(recent, dim=1)
def _build_long_memory(self, block_summaries: List[torch.Tensor]) -> torch.Tensor:
"""
Concatenate block summary tokens.
Input: list of [B, S, D]
Output: [B, Nb*S, D]
"""
if not block_summaries:
return torch.empty(0, device=self.router[-1].bias.device)
return torch.cat(block_summaries, dim=1)
def forward(
self,
x: torch.Tensor,
short_term: List[torch.Tensor],
block_summaries: List[torch.Tensor],
) -> torch.Tensor:
identity = x
short_mem = self._build_short_memory(short_term)
long_mem = self._build_long_memory(block_summaries)
if short_mem.numel() == 0:
short_ctx = torch.zeros_like(x)
else:
short_ctx = self.short_mem_attn(x, short_mem)
if long_mem.numel() == 0:
long_ctx = torch.zeros_like(x)
else:
long_ctx = self.long_mem_attn(x, long_mem)
# Token-wise routing
gates = F.softmax(self.router(x), dim=-1) # [B, T, 3]
hybrid = (
gates[..., 0:1] * identity +
gates[..., 1:2] * short_ctx +
gates[..., 2:3] * long_ctx
)
# Standard transformer update on top of routed residual input
h = self.norm1(hybrid)
attn_out, _ = self.self_attn(h, h, h, need_weights=False)
x = hybrid + attn_out
x = x + self.mlp(self.norm2(x))
return x
class SmallDRRTransformer(nn.Module):
def __init__(
self,
vocab_size: int = 32000,
dim: int = 512,
num_heads: int = 8,
num_layers: int = 12,
block_size: int = 4,
short_window: int = 4,
summary_tokens: int = 8,
dropout: float = 0.0,
max_seq_len: int = 2048,
):
super().__init__()
self.dim = dim
self.block_size = block_size
self.short_window = short_window
self.embed = nn.Embedding(vocab_size, dim)
self.pos_embed = nn.Embedding(max_seq_len, dim)
self.layers = nn.ModuleList([
DRRLayer(
dim=dim,
num_heads=num_heads,
short_window=short_window,
summary_tokens=summary_tokens,
dropout=dropout,
)
for _ in range(num_layers)
])
self.summary_pool = SummaryPool(dim, summary_tokens)
self.final_norm = nn.LayerNorm(dim)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
# input_ids: [B, T]
        B, T = input_ids.shape
        assert T <= self.pos_embed.num_embeddings, "sequence length exceeds max_seq_len"
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
x = self.embed(input_ids) + self.pos_embed(pos)
memory = DRRMemory(short_term=[], long_term=[])
for i, layer in enumerate(self.layers):
x = layer(x, memory.short_term, memory.long_term)
memory.short_term.append(x)
if len(memory.short_term) > self.short_window:
memory.short_term = memory.short_term[-self.short_window:]
# Every block_size layers, add summary tokens to long-term memory
if (i + 1) % self.block_size == 0:
summary = self.summary_pool(x) # [B, S, D]
memory.long_term.append(summary)
return self.final_norm(x)
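A quick smoke test of the full model (toy sizes of our choosing; a shape and control-flow check, not a training run):

if __name__ == "__main__":
    model = SmallDRRTransformer(vocab_size=1000, dim=128, num_heads=4, num_layers=8)
    input_ids = torch.randint(0, 1000, (2, 64))  # [B, T]
    hidden = model(input_ids)
    print(hidden.shape)  # torch.Size([2, 64, 128]) -> [B, T, D]; note: no LM head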
Depth Routing Residuals (DRR):
Adaptive Learned Routing for Stable and Efficient Information Flow in Deep Transformers
Grok Research (xAI) — Conceptual Proposal and Implementation
Abstract
Modern transformer architectures suffer from signal dilution and uncontrolled growth of activation magnitudes as model depth increases. Attention Residuals (AttnRes) recently showed that replacing fixed residual connections with per-layer softmax attention over previous outputs can improve gradient flow and scaling behavior. We introduce Depth Routing Residuals (DRR), a more expressive and memory-efficient generalization. Each layer learns a lightweight router that dynamically mixes three pathways: (1) a strong identity path, (2) short-term attention over recent layers, and (3) long-term attention over compressed summaries of previous blocks. DRR reduces memory overhead by approximately 65% compared to full AttnRes while aiming to maintain or improve training stability and performance.
1. Introduction
The standard residual connection in transformers is defined as
$$x_{l+1} = x_l + \mathcal{F}_l(x_l),$$
where $\mathcal{F}_l$ represents the sub-layer (self-attention + feed-forward network). Although simple and effective, this formulation forces every layer to contribute with a fixed weight of 1, leading to progressive dilution of early-layer signals and growing activation norms with depth.
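The norm-growth effect is easy to reproduce numerically. A toy sketch (random linear maps stand in for trained sub-layers, an assumption purely for illustration):

import torch

torch.manual_seed(0)
dim, depth = 512, 48
x = torch.randn(dim)
start_norm = x.norm().item()
for _ in range(depth):
    f = torch.randn(dim, dim) / dim ** 0.5  # stand-in for a sub-layer F_l
    x = x + f @ x  # fixed-weight residual: every layer adds with weight 1
print(f"norm grew from {start_norm:.1f} to {x.norm().item():.1f} over {depth} layers")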

Attention Residuals (Kimi Team, 2026) addressed part of this issue by allowing each layer to attend over previous hidden states using a learned query. However, the full version incurs significant memory and computational cost, often requiring block-wise approximations.
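For contrast with DRR below, a minimal sketch of the AttnRes mechanism as described above (our reading, not the reference implementation; the query is a learned per-layer parameter, and the full history of layer outputs must be kept):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnResResidual(nn.Module):
    """Learned per-layer query attends over all previous layer outputs."""
    def __init__(self, dim: int):
        super().__init__()
        self.query = nn.Parameter(torch.zeros(dim))

    def forward(self, history: list) -> torch.Tensor:
        stack = torch.stack(history, dim=2)            # [B, T, L, D] -- O(L) memory
        scores = (stack * self.query).sum(-1)          # [B, T, L]
        weights = F.softmax(scores / stack.size(-1) ** 0.5, dim=-1)
        return (weights.unsqueeze(-1) * stack).sum(2)  # [B, T, D]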
We propose Depth Routing Residuals (DRR) as a natural and more efficient extension that combines learned routing with block compression.
2. Method
We divide the model into $B$ blocks, each containing $K$ consecutive layers (the implementation in Section 3 uses $K = 4$, the block_size argument).
2.1 Block Compression
For each previous block $j$, we compute a compact summary of $S$ tokens:
$$s^{(j)} = \mathrm{AvgPool}_S\!\big(W \, \mathrm{LN}(h^{(j)})\big) \in \mathbb{R}^{S \times D},$$
where $h^{(j)}$ is the output of the last layer of block $j$, $\mathrm{LN}$ is a LayerNorm, $W$ a linear projection, and $\mathrm{AvgPool}_S$ adaptively average-pools the sequence axis down to $S$ positions. This compresses sequence length, not the feature dimension.
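This is exactly what SummaryPool in Section 3 implements; a quick shape check (assumes the Section 3 definitions are in scope):

import torch

pool = SummaryPool(dim=512, summary_tokens=8)
h_block = torch.randn(2, 128, 512)  # [B, T, D] output of a block
summary = pool(h_block)
print(summary.shape)                # torch.Size([2, 8, 512]) -> [B, S, D]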
2.2 Lightweight Router
At each layer $l$, a small router computes a token-wise gating vector
$$g_l = \mathrm{softmax}\!\big(W_2 \, \mathrm{GELU}(W_1 \, \mathrm{LN}(x_l))\big) \in \mathbb{R}^{T \times 3},$$
where the three components correspond to (1) the identity path (the token's own state), (2) short-term attention over the full token states of recent layers, and (3) long-term attention over the summary tokens of previous blocks.
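A standalone sketch of the gate computation, mirroring the identity-leaning initialization used in Section 3 (toy shapes of our choosing):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 512
router = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, 3))
nn.init.zeros_(router[-1].weight)
with torch.no_grad():
    router[-1].bias.copy_(torch.tensor([2.0, 0.0, -0.5]))  # favor identity at init

x = torch.randn(2, 16, dim)           # [B, T, D]
gates = F.softmax(router(x), dim=-1)  # [B, T, 3], one gate triple per token
print(gates[0, 0])                    # ~ (0.82, 0.11, 0.07) at initialization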
2.3 Hybrid Residual
The effective input to the sub-layer is computed as the token-wise mixture
$$\tilde{x}_l = g_{l,1} \odot x_l + g_{l,2} \odot \mathrm{Attn}_{\mathrm{short}}(x_l, M_{\mathrm{short}}) + g_{l,3} \odot \mathrm{Attn}_{\mathrm{long}}(x_l, M_{\mathrm{long}}),$$
where $M_{\mathrm{short}}$ concatenates the last few full-resolution layer outputs, $M_{\mathrm{long}}$ concatenates the block summaries $s^{(j)}$, and $\odot$ applies each token's gate to its state.
All attention operations use scaled dot-product attention. In both memory pathways, queries are projected from the current hidden states $x_l$, while keys and values are projected from the memory bank (unlike AttnRes, where the query is a learned per-layer parameter).
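A compact end-to-end sketch of the routed residual with dummy tensors (projections and memory construction elided; Section 3 has the full version):

import torch
import torch.nn.functional as F

B, H, T, M, D = 2, 8, 16, 64, 512
Hd = D // H
x = torch.randn(B, T, D)

def memory_read(q, k, v):
    # Cross-attention read over a memory bank via scaled dot-product attention.
    ctx = F.scaled_dot_product_attention(q, k, v, is_causal=False)  # [B, H, T, Hd]
    return ctx.transpose(1, 2).reshape(B, T, D)

rnd = lambda n: torch.randn(B, H, n, Hd)  # stand-in for projected q/k/v
short_ctx = memory_read(rnd(T), rnd(M), rnd(M))
long_ctx = memory_read(rnd(T), rnd(M), rnd(M))

gates = torch.softmax(torch.randn(B, T, 3), dim=-1)  # stand-in for router output
hybrid = gates[..., 0:1] * x + gates[..., 1:2] * short_ctx + gates[..., 2:3] * long_ctx
assert hybrid.shape == (B, T, D)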
Added cost: Approximately 0.8% extra parameters and 1.5% extra FLOPs compared to a standard PreNorm transformer.
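The exact overhead depends on dim, num_heads, and the router width; for a given configuration it can be measured directly from the Section 3 modules (a sketch, assuming they are in scope):

import torch.nn as nn

def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

layer = DRRLayer(dim=512, num_heads=8)
extra = sum(count_params(m) for m in (layer.router, layer.short_mem_attn, layer.long_mem_attn))
base = count_params(layer) - extra
print(f"extra params per layer: {extra} ({100 * extra / base:.1f}% of the base block)")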
Figure 1: Overview of a single DRR layer. The router dynamically balances the three information pathways.
4. Next Steps
This is a conceptual proposal. The next logical step is to train small-scale models (1B–3B parameters) with and without DRR on public datasets (e.g., FineWeb-Edu) and compare validation loss, gradient norms, and downstream performance.
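As a starting point for that comparison, a minimal single-step sketch (toy configuration; lm_head and the random batch are our stand-ins, not part of the proposal):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes the Section 3 classes are in scope. lm_head is a hypothetical output
# projection (the model above returns hidden states, not logits), and the random
# batch stands in for a real dataloader. Note the Section 3 model applies no
# causal mask, so a real language-modeling comparison would need to add one.
model = SmallDRRTransformer(vocab_size=32000, dim=256, num_heads=8, num_layers=8)
lm_head = nn.Linear(256, 32000, bias=False)
params = list(model.parameters()) + list(lm_head.parameters())
opt = torch.optim.AdamW(params, lr=3e-4)

input_ids = torch.randint(0, 32000, (2, 128))
labels = input_ids.roll(-1, dims=1)  # toy next-token targets

logits = lm_head(model(input_ids))  # [B, T, V]
loss = F.cross_entropy(logits.view(-1, 32000), labels.view(-1))
opt.zero_grad()
loss.backward()
grad_norm = nn.utils.clip_grad_norm_(params, max_norm=1.0)  # track gradient norms
opt.step()
print(f"loss {loss.item():.3f}, grad norm {grad_norm:.3f}")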