waylandzhang/Transformer-from-scratch

How should I understand the token generation procedure? I'm confused about the current context_length. I thought it was just a fixed parameter...

jimmmwong opened this issue · 3 comments

```python
import torch
import torch.nn.functional as F

def generate(self, idx, max_new_tokens):
    # idx is a (B, T) tensor of token indices in the current context
    for _ in range(max_new_tokens):
        # Crop idx to the last context_length tokens (the max size of our positional embeddings table)
        idx_crop = idx[:, -self.context_length:]
        # Get predictions
        logits, loss = self(idx_crop)
        # Take the last time step from logits, whose dimensions are (B, T, C)
        logits_last_timestep = logits[:, -1, :]
        # Apply softmax to get probabilities
        probs = F.softmax(input=logits_last_timestep, dim=-1)
        # Sample from the probability distribution
        idx_next = torch.multinomial(input=probs, num_samples=1)
        # Append the sampled indices idx_next to idx
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
```
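To make the role of `context_length` concrete: it is a fixed hyperparameter (the size of the positional-embedding table), but `idx` grows by one token on every iteration. The crop `idx[:, -self.context_length:]` therefore acts as a sliding window that always feeds the model at most the last `context_length` tokens. The sketch below mimics that loop with plain Python lists instead of tensors, for a single sequence (B = 1). The function `fake_next_token` is hypothetical and only stands in for the "run the model and sample" step.

```python
def fake_next_token(window):
    """Hypothetical stand-in for 'run the model and sample': returns last token + 1."""
    return window[-1] + 1


def generate_with_window(idx, max_new_tokens, context_length):
    """Mimic the crop-then-append loop of generate() for one sequence."""
    windows = []  # record what the model would have seen at each step
    for _ in range(max_new_tokens):
        # Same slicing as idx[:, -self.context_length:], applied to one row:
        # keep only the last context_length tokens (or all of them if there are fewer)
        idx_crop = idx[-context_length:]
        windows.append(list(idx_crop))
        # Append the new token: idx grows, context_length stays fixed
        idx = idx + [fake_next_token(idx_crop)]
    return idx, windows


if __name__ == "__main__":
    seq, seen = generate_with_window([0, 1], max_new_tokens=4, context_length=3)
    print(seq)  # [0, 1, 2, 3, 4, 5]
    for w in seen:
        print(w)  # [0, 1] / [0, 1, 2] / [1, 2, 3] / [2, 3, 4]
```

The generated sequence ends up six tokens long, yet the model input never exceeds three tokens. That is why the crop is needed once the text grows past the positional-embedding table, while `context_length` itself never changes.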

This part is a bit hard to explain in plain text; you can watch my video posts if you can understand Chinese.

How should I understand `idx_next = torch.multinomial(input=probs, num_samples=1)`? I thought it should take the index of the largest (or smallest) value in probs, but that's not what it does.

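It does return a token index; it's just that how the index is chosen is determined by the multinomial function. There are several ways to pick the next token, and the result is not necessarily the one with the highest probability.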
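To illustrate the difference: taking the index of the largest probability (greedy decoding, which is what `torch.argmax` would give) always returns the same token, whereas `torch.multinomial(probs, num_samples=1)` draws index `i` with probability `probs[i]`, so lower-probability tokens are chosen some of the time. The sketch below uses the standard library's `random.choices` as a plain-Python stand-in for `torch.multinomial`. It is not PyTorch's implementation, and the probability vector is made up for illustration.

```python
import random
from collections import Counter


def greedy_pick(probs):
    """Greedy decoding: index of the largest probability (like torch.argmax)."""
    return max(range(len(probs)), key=lambda i: probs[i])


def sample_pick(probs, rng):
    """Sampling: draw index i with probability probs[i] (the idea behind torch.multinomial)."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    probs = [0.1, 0.6, 0.3]  # hypothetical softmax output over a 3-token vocabulary
    rng = random.Random(0)   # fixed seed so the run is reproducible

    print("greedy:", greedy_pick(probs))  # always index 1
    counts = Counter(sample_pick(probs, rng) for _ in range(10_000))
    for token_id in sorted(counts):
        # empirical frequencies come out close to 0.1, 0.6 and 0.3
        print(token_id, counts[token_id] / 10_000)
```

Sampling is why calling `generate()` twice on the same prompt can produce different text; replacing the multinomial draw with argmax would make the output deterministic.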