Potential bug with the PyTorch iterator
MattePalte opened this issue · 1 comment
Hi @wasiahmad
I noticed a potential bug while processing a small test dataset:
- test dataset size = 78 functions
- batch size = 64
At this location:
https://github.com/wasiahmad/NeuralCodeSum/blob/master/main/test.py#L275
def validate_official(args, data_loader, model):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.
    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = Timer()
    translator = build_translator(model, args)
    builder = TranslationBuilder(model.tgt_dict,
                                 n_best=args.n_best,
                                 replace_unk=args.replace_unk)
    # Run through examples
    examples = 0
    trans_dict, sources = dict(), dict()
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for batch_no, ex in enumerate(pbar):
            batch_size = ex['batch_size']  # POTENTIAL BUG
            ids = list(range(batch_no * batch_size,
                             (batch_no * batch_size) + batch_size))
            batch_inputs = prepare_batch(ex, model)
Here the ids are computed from batch_size, which assumes that the batch size is constant. Unfortunately, stepping through with pdb I noticed that the batch_size coming from the data_loader (PyTorch code) is constant only until the last batch. For example, with 78 records the first batch has batch_size = 64, but the second and last one has batch_size = 14, and this messes up the writing of the results: for batch_no = 1 the ids become range(14, 28), so the last 14 predictions are saved to ids that were already assigned to records of the first batch, while ids 64 to 77 are never written.
Maybe this is specific to my PyTorch version:
pytorch=1.5.1=py3.6_cuda10.1.243_cudnn7.6.3_0
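To make the id collision concrete, here is a small pure-Python simulation (no PyTorch needed). It is a sketch that assumes the loader yields consecutive chunks and keeps the last, partial batch, which is what torch.utils.data.DataLoader does with its default drop_last=False. The helpers batch_sizes and buggy_ids are hypothetical names used only for illustration.

```python
# Hypothetical simulation of the batching described above, without PyTorch.
# Assumes consecutive chunks with the last, partial batch kept
# (torch.utils.data.DataLoader behaviour with drop_last=False).


def batch_sizes(num_examples, batch_size):
    """Return the size of every batch yielded for num_examples items."""
    sizes = []
    remaining = num_examples
    while remaining > 0:
        sizes.append(min(batch_size, remaining))
        remaining -= sizes[-1]
    return sizes


def buggy_ids(num_examples, batch_size):
    """Replay the id computation from validate_official for every batch."""
    all_ids = []
    for batch_no, size in enumerate(batch_sizes(num_examples, batch_size)):
        # Mirrors: batch_size = ex['batch_size'] followed by the range() call.
        all_ids.extend(range(batch_no * size, batch_no * size + size))
    return all_ids


if __name__ == "__main__":
    ids = buggy_ids(78, 64)
    print(batch_sizes(78, 64))  # [64, 14]
    print(ids[64:] == list(range(14, 28)))  # True: last batch reuses ids 14..27
    missing = sorted(set(range(78)) - set(ids))
    print(missing[0], missing[-1], len(missing))  # 64 77 14
```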
Anyway, there is a quick fix:
def validate_official(args, data_loader, model, batch_size):
where batch_size is taken from args.batch_size and stays constant throughout the entire validation procedure.
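A minimal sketch of that quick fix, assuming only the last batch can be shorter than the configured size: the constant batch size (args.batch_size) sets the starting offset of each batch, while the range length uses the actual size of the current batch. Using the actual size for the length is an assumption here, meant to keep ids aligned one-to-one with the predictions of a short last batch; fixed_ids_constant is a hypothetical helper that replays a sequence of batch sizes.

```python
def fixed_ids_constant(batch_sizes_seen, configured_batch_size):
    """Assign ids using a constant per-batch offset (e.g. args.batch_size)."""
    all_ids = []
    for batch_no, actual_size in enumerate(batch_sizes_seen):
        start = batch_no * configured_batch_size  # constant across batches
        # Length comes from the real batch, so a short last batch gets
        # only as many ids as it has examples.
        all_ids.extend(range(start, start + actual_size))
    return all_ids


if __name__ == "__main__":
    ids = fixed_ids_constant([64, 14], configured_batch_size=64)
    print(ids == list(range(78)))  # True
```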
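Luckily, the bug only shows up when the dataset size is not a multiple of the batch size, and only the last, partial batch is mis-indexed, so at most one batch of predictions (fewer than 64) is written to the wrong ids. On a large test set this is negligible, but on small test sets it is a significant problem.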
@wasiahmad let me know whether the steps to reproduce the problem are clear. I am curious whether you can confirm the bug or whether I am missing something in my reasoning.
Thanks in advance,
Matteo
@MattePalte Yes, this is a bug. Since we already keep the running counter examples, we can simply do the following:
ids = list(range(examples, examples + batch_size))
Can you confirm whether this fix solves the bug?
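A sketch of the counter-based fix, assuming examples is advanced by the actual size of each batch after that batch's ids are computed (the increment is not shown in the excerpt above). fixed_ids_counter is a hypothetical helper that replays a sequence of batch sizes.

```python
def fixed_ids_counter(batch_sizes_seen):
    """Assign ids with a running counter of processed examples."""
    examples = 0  # number of examples processed so far
    all_ids = []
    for actual_size in batch_sizes_seen:
        ids = list(range(examples, examples + actual_size))
        all_ids.extend(ids)
        examples += actual_size  # assumed increment, once per batch
    return all_ids


if __name__ == "__main__":
    print(fixed_ids_counter([64, 14]) == list(range(78)))  # True
    print(fixed_ids_counter([10, 64, 4]) == list(range(78)))  # True
```

Unlike the constant batch size fix, the running counter does not assume that only the last batch is short, so it would stay correct even if a sampler yielded batches of varying size.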