NeuralMemory forward when num_chunks == 0 #59

@x54-729

Description

Hello, I really appreciate this work, but while reading the code I have a question.

If the input's seq_len < chunk_size, which leads to num_chunks == 0 in NeuralMemory, will the weights of memory_model ever be updated? In store_memories, the function seems to return early, with only unweighted_mem_model_loss and adaptive_lr changed:

```python
if not exists(past_state):
    # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
    minibatch_init_weight = weights
    init_momentum = self.init_momentum(batch)
    past_state = (minibatch_init_weight, init_momentum)

past_last_update, past_last_momentum = past_state

# early return if sequence length less than chunk size
if num_chunks == 0:
    updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
    next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)

    output = (updates, next_store_state)

    if not return_surprises:
        return output

    return (*output, (unweighted_mem_model_loss, adaptive_lr))
```

I also don't see any update to the parameter weights in NeuralMemory.forward or MemoryAsContextTransformer.forward in this situation. It might be a naive question, but I want to know whether I have missed some state.
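For illustration, here is a minimal sketch (hypothetical, not the repo's actual code) of the chunking arithmetic I'm describing: assuming `num_chunks = seq_len // chunk_size`, any sequence shorter than `chunk_size` floors to zero chunks, so the early-return branch fires, the memory weights pass through unchanged, and the unprocessed tokens are carried forward as the remainder:

```python
# Hypothetical sketch of the chunked-update arithmetic; names are illustrative.

def store_step(weights: dict, seq: list, chunk_size: int):
    """Return (weights, remainder) after one store call."""
    seq_len = len(seq)
    num_chunks = seq_len // chunk_size  # floors to 0 when seq_len < chunk_size

    if num_chunks == 0:
        # Too few tokens to form even one chunk: no update step is taken,
        # and the whole sequence is kept as remainder for the next call.
        return weights, seq

    processed = num_chunks * chunk_size
    remainder = seq[processed:]
    # ... the real implementation would run num_chunks memory-update steps here ...
    updated = dict(weights)  # placeholder for the actual weight update
    return updated, remainder

w = {'W': 0.0}

new_w, rem = store_step(w, list(range(5)), chunk_size=8)
assert new_w is w and len(rem) == 5   # seq_len (5) < chunk_size (8): no update

new_w, rem = store_step(w, list(range(20)), chunk_size=8)
assert len(rem) == 4                  # 20 = 2*8 + 4 tokens carried over
```

Under this reading, a caller that always feeds sequences shorter than chunk_size would never trigger a weight update, which is exactly the behavior I'm asking about.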
