Hi, thanks for the great work on this repository.
I have a question regarding the implementation of “memory as context” and how it aligns with the formulation in the paper (Section 4.1).
### Paper formulation
According to the paper, the computation flow for memory as context is:
- Retrieve historical information from long-term memory using the current segment as query:
  \[
  h_t = \mathcal{M}^*_{t-1}(q_t)
  \]
- Apply attention over the persistent memory parameters, the retrieved memory, and the current segment:
  \[
  \tilde{S}^{(t)} = [p_1, \dots, p_{N_p}] \;\|\; h_t \;\|\; S^{(t)}
  \]
  \[
  y_t = \text{Attn}\big(\tilde{S}^{(t)}\big)
  \]
- Update the long-term memory using the attention output:
  \[
  \mathcal{M}_t = \mathcal{M}_{t-1}(y_t)
  \]
Conceptually, this corresponds to the order:
retrieve → attention → memory update
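For reference, here is a minimal sketch of that ordering, assuming a hypothetical `memory` object with separate `retrieve` and `update` methods (these names are placeholders for the paper's two steps, not the repository's API):

```python
import torch

# Minimal sketch of the paper's ordering (Section 4.1). `memory`, `attn` and
# `persistent` are placeholder objects, not the repository's API.
def mac_step(segment, memory, attn, persistent):
    # 1. retrieve: h_t = M*_{t-1}(q_t), using the current segment as query
    h_t = memory.retrieve(segment)

    # 2. attend over [persistent tokens || retrieved memory || current segment]
    s_tilde = torch.cat((persistent, h_t, segment), dim = -2)
    y_t = attn(s_tilde)

    # 3. update: M_t = M_{t-1}(y_t), using the attention output only
    memory.update(y_t)
    return y_t
```

The key point in this reading is that the update only ever sees \( y_t \), i.e. it runs strictly after attention.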
### Code behavior
In the code, I see the following pattern:

```python
retrieved, next_neural_mem_cache = mem.forward(
    qkv_mem_input,
    state = next(neural_mem_caches, None),
    prev_weights = mem_weight_residual
) # retrieved h_t

attn_out, (values, next_kv_cache) = attn(
    attn_in,
    value_residual = value_residual,
    disable_flex_attn = disable_flex_attn,
    flex_attn_fn = flex_attn_fn,
    output_gating = attn_out_gates,
    cache = next(kv_caches, None)
) # y_t, computed after the memory forward above
```
From my understanding:
- Calling `mem.forward(...)` already returns both:
  - `retrieved` (i.e. \( h_t \))
  - `next_neural_mem_cache` (the updated memory state)
- Attention is then computed **after** this memory forward pass.
This seems slightly different from the paper, where the memory update happens *after* the attention output \( y_t \) is computed (a contrast sketched below).
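To make the contrast concrete, here is a condensed, purely illustrative sketch. The signatures loosely mirror the snippet above, and `mem.retrieve` / `mem.update` are hypothetical names I made up for the paper's two steps, not the repository's API:

```python
# Pattern observed in the code: retrieval and the memory-state advance happen
# together inside mem.forward(...), before attention runs.
retrieved, next_neural_mem_cache = mem.forward(qkv_mem_input)  # retrieve + update in one call
attn_out, _ = attn(attn_in)                                    # y_t comes afterwards

# A literal reading of the paper would defer the update until y_t exists:
#   retrieved = mem.retrieve(qkv_mem_input)       # h_t only, no state change
#   attn_out, _ = attn(attn_in)                   # y_t = Attn(S~^{(t)})
#   next_memory_state = mem.update(attn_out)      # M_t = M_{t-1}(y_t)
```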
---
### Questions
1. Is the memory update inside `mem.forward` equivalent to the paper's update step \( \mathcal{M}_t = \mathcal{M}_{t-1}(y_t) \), or is this a deliberate reordering / approximation for implementation efficiency?
2. Does updating the memory inside `mem.forward` risk information leakage from the current step into the memory used by attention?
3. Conceptually, is this implementation still faithful to the **“memory as context”** formulation, or is the paper describing an idealized ordering?