
Proposal: Depth Routing Residuals (DRR) — An efficient extension of AttnRes with learned routing and block compression #11

@Lockyd


Depth Routing Residuals (DRR):
Adaptive Learned Routing for Stable and Efficient Information Flow in Deep Transformers
Grok Research (xAI) — Conceptual Proposal and Implementation

Abstract
Modern transformer architectures suffer from signal dilution and uncontrolled growth of activation magnitudes as model depth increases. Attention Residuals (AttnRes) recently showed that replacing fixed residual connections with per-layer softmax attention over previous outputs can improve gradient flow and scaling behavior. We introduce Depth Routing Residuals (DRR), a more expressive and memory-efficient generalization. Each layer learns a lightweight router that dynamically mixes three pathways: (1) a strong identity path, (2) short-term attention over recent layers, and (3) long-term attention over compressed summaries of previous blocks. DRR reduces memory overhead by approximately 65% compared to full AttnRes while aiming to maintain or improve training stability and performance.

1. Introduction

The standard residual connection in transformers is defined as

$$x_{l+1} = x_l + F_l(x_l),$$

where $F_l(\cdot)$ represents the sub-layer (self-attention + feed-forward network). Although simple and effective, this formulation forces every layer to contribute with a fixed weight of 1, leading to progressive dilution of early-layer signals and growing activation norms with depth.
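To see why fixed-weight residuals inflate activation norms, here is a small numerical sketch (a hypothetical toy setup, not the proposed method): treating each sub-layer output as an independent, roughly unit-norm vector, the residual-stream norm grows approximately as $\sqrt{L}$:

```python
import torch

torch.manual_seed(0)
d, depth = 1024, 64
x = torch.randn(d) / d ** 0.5          # roughly unit-norm starting state

norms = [x.norm().item()]
for _ in range(depth):
    f_x = torch.randn(d) / d ** 0.5    # stand-in sub-layer output, unit norm
    x = x + f_x                        # fixed-weight residual update
    norms.append(x.norm().item())

print(f"||x_0|| = {norms[0]:.2f}, ||x_L|| = {norms[-1]:.2f}")
```

After 64 layers the stream norm is roughly 8x the initial norm, so each individual layer's relative contribution shrinks with depth.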
Attention Residuals (Kimi Team, 2026) addressed part of this issue by allowing each layer to attend over previous hidden states using a learned query. However, the full version incurs significant memory and computational cost, often requiring block-wise approximations.
We propose Depth Routing Residuals (DRR) as a natural and more efficient extension that combines learned routing with block compression.
2. Method

We divide the model's $L$ layers into $B$ blocks, each containing $L/B$ consecutive layers.

2.1 Block Compression

For each previous block $j$, we compute a compact summary of its output $h_j \in \mathbb{R}^{T \times d}$:

$$s_j = \mathrm{Pool}_S\big(\mathrm{Proj}(\mathrm{LN}(h_j))\big) \in \mathbb{R}^{S \times d},$$

where $\mathrm{Pool}_S$ averages over the sequence length down to $S$ summary tokens (see `SummaryPool` below).

2.2 Lightweight Router

At each layer $l$, a small router computes a token-wise gating vector:

$$g_l = \mathrm{softmax}\big(W_2\,\mathrm{GELU}(W_1\,\mathrm{LN}(x_l))\big) \in \mathbb{R}^3,$$

where the three components correspond to:

- $g_l^{(1)}$: the identity path,
- $g_l^{(2)}$: short-term attention over recent layer outputs,
- $g_l^{(3)}$: long-term attention over block summaries.
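The router and its identity-biased initialization can be sketched in isolation (a minimal standalone example mirroring the `DRRLayer` initialization below; with the final weight zeroed, the gates at init are exactly the softmax of the three path biases):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim = 64
router = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, dim),
    nn.GELU(),
    nn.Linear(dim, 3),
)
# Identity-biased init: zero the final weight, then set per-path biases.
nn.init.zeros_(router[-1].weight)
with torch.no_grad():
    router[-1].bias.copy_(torch.tensor([2.0, 0.0, -0.5]))

x = torch.randn(2, 10, dim)           # [B, T, D]
gates = F.softmax(router(x), dim=-1)  # [B, T, 3], token-wise
print(gates[0, 0])                    # ~[0.82, 0.11, 0.07] for every token at init
```

Because the final weight starts at zero, every token initially routes about 82% of its mass through the identity path, and the router only learns to shift mass toward the memory paths where it helps.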

2.3 Hybrid Residual

The effective input to the sub-layer is computed as:

$$\tilde{x}_l = g_l^{(1)}\,x_l + g_l^{(2)}\,\mathrm{Attn}\big(x_l, M_{\mathrm{short}}\big) + g_l^{(3)}\,\mathrm{Attn}\big(x_l, M_{\mathrm{long}}\big),$$

where $M_{\mathrm{short}}$ concatenates the outputs of the last $W$ layers and $M_{\mathrm{long}}$ concatenates all block summaries $s_j$. All attention operations use scaled dot-product attention; the queries are per-layer linear projections of the current hidden states $x_l$ (see `DepthMemoryAttention` below).
Added cost: approximately 0.8% extra parameters and 1.5% extra FLOPs compared to a standard PreNorm transformer, assuming a narrow router and memory-attention projections shared across layers; the reference implementation below uses full per-layer projections for clarity, which is heavier.
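As a rough sanity check on where the overhead comes from, one can count parameters (a sketch ignoring biases and norms in the 12·d² baseline block; exact figures depend on router width and projection sharing):

```python
import torch.nn as nn

dim = 512
# Baseline PreNorm block parameters: self-attention (~4*dim^2) + MLP (~8*dim^2).
base_block = 12 * dim * dim

router = nn.Sequential(
    nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, 3)
)
router_params = sum(p.numel() for p in router.parameters())

# Two memory-attention modules with full q/k/v/out projections: 2 * 4*dim^2.
mem_attn_params = 8 * dim * dim

print(f"router / base block:   {router_params / base_block:.1%}")
print(f"mem-attn / base block: {mem_attn_params / base_block:.1%}")
```

The full per-layer memory projections dominate the cost, so reaching sub-1% overall overhead requires sharing those projections across layers, making them low-rank, or narrowing the router.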

Figure 1: Overview of a single DRR layer. The router dynamically balances the three information pathways.
3. PyTorch Implementation

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class DRRMemory:
    short_term: List[torch.Tensor]      # list of [B, T, D]
    long_term: List[torch.Tensor]       # list of [B, S, D]


class DepthMemoryAttention(nn.Module):
    """
    Lightweight multi-head attention over an external memory bank.
    Query comes from current hidden states x: [B, T, D]
    Memory comes from past states/summaries: [B, M, D]
    Output: [B, T, D]
    """
    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.dropout = dropout

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        # [B, N, D] -> [B, H, N, Hd]
        B, N, D = x.shape
        x = x.view(B, N, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,          # [B, T, D]
        memory: torch.Tensor,     # [B, M, D]
    ) -> torch.Tensor:
        if memory.numel() == 0:
            return torch.zeros_like(x)

        q = self._shape(self.q_proj(x))
        k = self._shape(self.k_proj(memory))
        v = self._shape(self.v_proj(memory))

        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )  # [B, H, T, Hd]

        attn = attn.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.dim)
        return self.out_proj(attn)


class SummaryPool(nn.Module):
    """
    Converts a block output [B, T, D] into S summary tokens [B, S, D].
    This compresses sequence length, not feature dimension.
    Much safer than ad-hoc feature compression for residual routing.
    """
    def __init__(self, dim: int, summary_tokens: int):
        super().__init__()
        self.summary_tokens = summary_tokens
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, D]
        x = self.proj(self.norm(x))

        # Adaptive average pool over sequence length: [B, T, D] -> [B, S, D]
        x_t = x.transpose(1, 2)  # [B, D, T]
        pooled = F.adaptive_avg_pool1d(x_t, self.summary_tokens)  # [B, D, S]
        return pooled.transpose(1, 2).contiguous()
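The pooling step can be checked in isolation; with T = 128 and S = 8, `adaptive_avg_pool1d` averages a contiguous span of 16 positions into each summary token:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 128, 512)  # block output [B, T, D]

# Pool over sequence length: [B, T, D] -> [B, D, T] -> [B, D, S] -> [B, S, D]
pooled = F.adaptive_avg_pool1d(x.transpose(1, 2), 8).transpose(1, 2)
print(pooled.shape)  # torch.Size([2, 8, 512])
```

Since 128 divides evenly into 8 bins, the first summary token is exactly the mean of positions 0–15, the second of positions 16–31, and so on.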


class DRRLayer(nn.Module):
    """
    Improved Depth Routing Residuals layer:
      1) identity path
      2) short-term attention over recent full token states
      3) long-term attention over block summary tokens

    Routing is token-wise: [B, T, 3], not sequence-global.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        short_window: int = 4,
        summary_tokens: int = 8,
        dropout: float = 0.0,
        router_hidden_mult: int = 1,
    ):
        super().__init__()
        self.dim = dim
        self.short_window = short_window
        self.summary_tokens = summary_tokens

        # Token-wise router
        hidden = dim * router_hidden_mult
        self.router = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 3),
        )

        # Bias toward identity at init for stability
        nn.init.zeros_(self.router[-1].weight)
        nn.init.zeros_(self.router[-1].bias)
        with torch.no_grad():
            self.router[-1].bias[0] = 2.0   # identity
            self.router[-1].bias[1] = 0.0   # short
            self.router[-1].bias[2] = -0.5  # long

        self.short_mem_attn = DepthMemoryAttention(dim, num_heads, dropout)
        self.long_mem_attn = DepthMemoryAttention(dim, num_heads, dropout)

        # Standard PreNorm transformer block
        self.norm1 = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

    def _build_short_memory(self, short_term: List[torch.Tensor]) -> torch.Tensor:
        """
        Concatenate recent layer outputs along the memory axis.
        Input: list of [B, T, D]
        Output: [B, M, D] where M = len(short_term) * T
        """
        if not short_term:
            return torch.empty(0, device=self.router[-1].bias.device)

        recent = short_term[-self.short_window:]
        return torch.cat(recent, dim=1)

    def _build_long_memory(self, block_summaries: List[torch.Tensor]) -> torch.Tensor:
        """
        Concatenate block summary tokens.
        Input: list of [B, S, D]
        Output: [B, Nb*S, D]
        """
        if not block_summaries:
            return torch.empty(0, device=self.router[-1].bias.device)

        return torch.cat(block_summaries, dim=1)

    def forward(
        self,
        x: torch.Tensor,
        short_term: List[torch.Tensor],
        block_summaries: List[torch.Tensor],
    ) -> torch.Tensor:
        identity = x

        short_mem = self._build_short_memory(short_term)
        long_mem = self._build_long_memory(block_summaries)

        if short_mem.numel() == 0:
            short_ctx = torch.zeros_like(x)
        else:
            short_ctx = self.short_mem_attn(x, short_mem)

        if long_mem.numel() == 0:
            long_ctx = torch.zeros_like(x)
        else:
            long_ctx = self.long_mem_attn(x, long_mem)

        # Token-wise routing
        gates = F.softmax(self.router(x), dim=-1)  # [B, T, 3]

        hybrid = (
            gates[..., 0:1] * identity +
            gates[..., 1:2] * short_ctx +
            gates[..., 2:3] * long_ctx
        )

        # Standard transformer update on top of routed residual input,
        # with a causal mask for autoregressive modeling
        h = self.norm1(hybrid)
        T = h.size(1)
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=h.device), diagonal=1
        )
        attn_out, _ = self.self_attn(
            h, h, h, attn_mask=causal_mask, need_weights=False
        )
        x = hybrid + attn_out

        x = x + self.mlp(self.norm2(x))
        return x


class SmallDRRTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int = 32000,
        dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 12,
        block_size: int = 4,
        short_window: int = 4,
        summary_tokens: int = 8,
        dropout: float = 0.0,
        max_seq_len: int = 2048,
    ):
        super().__init__()
        self.dim = dim
        self.block_size = block_size
        self.short_window = short_window

        self.embed = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_seq_len, dim)

        self.layers = nn.ModuleList([
            DRRLayer(
                dim=dim,
                num_heads=num_heads,
                short_window=short_window,
                summary_tokens=summary_tokens,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.summary_pool = SummaryPool(dim, summary_tokens)
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # input_ids: [B, T]
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)

        x = self.embed(input_ids) + self.pos_embed(pos)

        memory = DRRMemory(short_term=[], long_term=[])

        for i, layer in enumerate(self.layers):
            x = layer(x, memory.short_term, memory.long_term)

            memory.short_term.append(x)
            if len(memory.short_term) > self.short_window:
                memory.short_term = memory.short_term[-self.short_window:]

            # Every block_size layers, add summary tokens to long-term memory
            if (i + 1) % self.block_size == 0:
                summary = self.summary_pool(x)   # [B, S, D]
                memory.long_term.append(summary)

        return self.final_norm(x)
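The memory schedule above can be traced with a pure-Python simulation (a hypothetical trace assuming num_layers=12, block_size=4, short_window=4, matching the defaults):

```python
# Simulate which memories each layer sees in SmallDRRTransformer.forward.
num_layers, block_size, short_window = 12, 4, 4

short_term, long_term = [], []
seen_long = []  # number of block summaries visible to each layer

for i in range(num_layers):
    seen_long.append(len(long_term))
    short_term.append(f"layer_{i}_output")
    short_term = short_term[-short_window:]       # keep only recent layers
    if (i + 1) % block_size == 0:                 # end of a block
        long_term.append(f"summary_of_block_{i // block_size}")

print(seen_long)       # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
print(len(long_term))  # 3
```

Layers in the first block see no long-term memory at all (their long gate contributes nothing), while each later block attends to summaries of all completed blocks before it.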
4. Next Steps

This is a conceptual proposal. The next logical step is to train small-scale models (1B–3B parameters) with and without DRR on public datasets (e.g. FineWeb-Edu) and compare validation loss, gradient norms, and downstream performance.
