bayesgroup/code_transformers

Questions about calculating MRR and predicting the code block


From my understanding, the mrr function uses the variable where to filter out tokens in memory (e.g., the first half of the tokens for intermediate blocks, as explained in #4), unknown tokens, and padding tokens. Then the _mrr function sorts the top-10 predicted tokens for each position and uses the rank of the correct token at each position to calculate MRR.
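To check that I read this correctly, here is roughly what I think mrr/_mrr compute (my own sketch, not the repo's code; logits and targets are assumed to be already filtered by where):

import torch

def mrr_top10(logits, targets, k=10):
    # logits: (N, vocab_size) scores for the evaluated positions
    # targets: (N,) ground-truth token ids
    topk = logits.topk(k, dim=-1).indices    # (N, k) candidate ids, best first
    hits = topk == targets.unsqueeze(-1)     # True where the target appears
    ranks = hits.float().argmax(dim=-1) + 1  # 1-based rank of the correct token
    rr = torch.where(hits.any(dim=-1), 1.0 / ranks, torch.zeros(len(ranks)))
    return rr.mean()                         # reciprocal rank is 0 if not in top-k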

My question is: are all positions in this code block predicted at the same time, or one by one, when calculating MRR (i.e., in the _step() function of the LightningModel class)? Say the code block is [range(250, 750), 250]; then range(250, 500) is the memory and range(500, 750) is to be predicted. When we calculate the MRR of this block, is it computed over all the tokens in range(500, 750)?
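To make the example concrete, this is how I picture the split (a hypothetical illustration, not the repo's code; I take the second element of the block to be the number of memory tokens, matching batch['extended'] in my script below):

block_positions = list(range(250, 750))  # 500 tokens in this block
n_memory = 250                           # second element of [range(250, 750), 250]
memory = block_positions[:n_memory]      # tokens 250..499: context only
to_pred = block_positions[n_memory:]     # tokens 500..749: positions scored by MRR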

How does it complete the code block? Does it predict the 501st token based on range(250, 500) and then the 502nd token based on range(250, 501)? Could you explain this in more detail and point me to the related code?

Besides, how can I use a trained model for inference and map the predicted indices back to vocabulary tokens? Here is the script I wrote, but I'm not sure it is correct. Please let me know what you think. Thanks for your help!

...
import torch
from torch.utils.data import DataLoader

args.load = 'path to the trained model.ckpt'

model = LightningModel(args)  # assuming the checkpoint in args.load is restored here
model.eval()

# use eval_dataset for example; batch_size=1 matches the assumption in idx2vocab_func
eval_dataset = model.eval_dataset
eval_dataloader = DataLoader(eval_dataset, batch_size=1,
    num_workers=args.num_workers, collate_fn=model.collate_fn,
    drop_last=False, pin_memory=True, shuffle=False)

# map tensor indices back to vocabulary strings
def idx2vocab_func(batch, idx2vocab):  # assuming batch size = 1
    vocab_len = len(idx2vocab)
    batch = batch.view(-1).tolist()
    input_v = []
    for i in batch:
        if i >= vocab_len:  # indices beyond the vocab (e.g., OOV ids) map to UNK
            input_v.append('UNK')
        else:
            input_v.append(idx2vocab[i])
    return input_v

def get_pred(batch):
    y = batch['input_seq']['values']
    y_pred_types, y_pred_values = model(batch['input_seq'], rel=batch['rel_mask'], positions=batch['positions'])

    # mask out the memory: positions with index < batch['extended'] are
    # context only, so only the remaining positions are kept for prediction
    ext = batch['extended'].unsqueeze(-1).repeat(1, y.size(-1))
    ext_ids = torch.arange(y.size(-1), device=ext.device).view(1, -1).repeat(*(y.size()[:-1] + (1,)))
    where = (ext_ids >= ext).view(-1)

    y_pred_values = y_pred_values.view(-1, y_pred_values.size(-1))[where]

    _, y_pred_values = torch.topk(y_pred_values, k=1, dim=-1)  # choose the top-1 predicted token
    return y_pred_values.view(-1)

with torch.no_grad():
    for i, sample in enumerate(eval_dataloader):
        print('---------------Input----------------')
        input_v = sample['input_seq']['values']
        print(input_v.shape)
        print('Vocab:', ' '.join(idx2vocab_func(input_v, idx2vocab_value)))

        print('---------------Target----------------')
        target_v = sample['target_seq']['values']
        print(target_v.shape)
        print('Vocab:', ' '.join(idx2vocab_func(target_v, idx2vocab_value)))

        print('---------------Predicted----------------')
        # note: pred_v covers only the positions past the memory (the `where`
        # mask in get_pred), so it is shorter than input_v and target_v
        pred_v = get_pred(sample)
        print(pred_v.shape)
        print('Vocab:', ' '.join(idx2vocab_func(pred_v, idx2vocab_value)))

Hi, thank you for the question! For code completion evaluation, we only predict the next type/value given the ground-truth context, as in the teacher-forcing regime. Predicting the whole subtree given the context is out of the scope of this work.
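In other words, evaluation needs only one forward pass per block: the ground-truth sequence is fed in, and position t is scored on its prediction of token t+1; no generated token is ever fed back into the context. A minimal sketch, assuming a causal model that returns per-position logits (not the repo's actual code):

import torch

def teacher_forcing_accuracy(model, input_ids):
    # input_ids: (1, T) ground-truth token ids; a causal mask inside the
    # model lets one pass produce a next-token distribution at every position
    logits = model(input_ids)                 # (1, T, vocab_size)
    preds = logits[:, :-1].argmax(dim=-1)     # prediction for token t+1 made at position t
    targets = input_ids[:, 1:]                # ground-truth next tokens
    return (preds == targets).float().mean()  # next-token accuracy under teacher forcing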