def forward(
    self,
    seq,
    store_seq = None,
    state: NeuralMemState | None = None,
    detach_mem_state = False,
    prev_weights = None,
    store_mask: Tensor | None = None,
    return_surprises = False,
    ttt_batch_size: int | None = None
):
    is_multi_input = self.qkv_receives_diff_views
    # handle single token
    if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
        seq = rearrange(seq, '... b d -> ... b 1 d')
    is_single_token = seq.shape[-2] == 1
    # if different views for qkv, then
    if is_multi_input:
        retrieve_seq, seq = seq[0], seq[1:]
    else:
        retrieve_seq = seq
    # handle previous state init
    if not exists(state):
        state = (0, None, None, None, None)
    seq_index, weights, cache_store_seq, past_state, updates = state
Above is the top of the `forward` function.

Notice that `updates` is computed from the current `state`. Further down, `updates` is set to `None` and then accumulated via `accum_updates`. Is this intended? The code implies that the `updates` entry in the `state` is not actually required.

Thanks, Gordon.