Hello, I really appreciate this work, but while reading the code I ran into a question.
If the input's `seq_len < chunk_size`, which leads to `num_chunks == 0` in `NeuralMemory`, will the weights of `memory_model` ever be updated? In `store_memories`, the function seems to return with only `unweighted_mem_model_loss` and `adaptive_lr` changed:
```python
if not exists(past_state):
    # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
    minibatch_init_weight = weights
    init_momentum = self.init_momentum(batch)

    past_state = (minibatch_init_weight, init_momentum)

past_last_update, past_last_momentum = past_state

# early return if sequence length less than chunk size

if num_chunks == 0:
    updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
    next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)

    output = (updates, next_store_state)

    if not return_surprises:
        return output

    return (*output, (unweighted_mem_model_loss, adaptive_lr))
```
and I didn't see any update of the parameter `weights` in `NeuralMemory.forward` or `MemoryAsContextTransformer.forward` in this situation either. It might be a stupid question, but I really want to know whether I have missed some state.
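To make the concern concrete, here is a minimal sketch of the early-return path as I understand it. The function name `store_memories_sketch` and its signature are hypothetical (not the actual titans-pytorch API); it only mimics the `num_chunks == 0` branch above, where the incoming weights appear to be reshaped and passed through rather than updated:

```python
import torch

def store_memories_sketch(weights: dict, seq_len: int, chunk_size: int):
    # Hypothetical stand-in for the num_chunks == 0 branch of store_memories.
    num_chunks = seq_len // chunk_size
    if num_chunks == 0:
        # mimics rearrange_dict_values(weights, 'bh ... -> bh 1 ...'):
        # a new axis is inserted, but the underlying tensors are untouched
        updates = {k: v.unsqueeze(1) for k, v in weights.items()}
        return updates, weights  # weights returned unchanged, no gradient step
    raise NotImplementedError('normal chunked-update path omitted')

w = {'proj': torch.randn(4, 4)}
updates, next_weights = store_memories_sketch(w, seq_len=3, chunk_size=8)
assert next_weights['proj'] is w['proj']  # exact same tensor object comes back
```

If this sketch matches the real behavior, then for a batch whose `seq_len` never reaches `chunk_size`, `memory_model` would never take an update step, which is exactly what I am asking about.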
(The snippet above is from titans-pytorch/titans_pytorch/neural_memory.py, lines 726 to 747 at commit 7874d35.)