apple/ml-4m

Input masks for generation - Potential small bug.


Looks like there may be a small bug in the generation code:

eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()
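To make the behavior concrete, here is a minimal sketch on a toy batch. The token values and `eos_id = 2` are hypothetical, chosen only for illustration. On a 2D tensor, single-argument `torch.where` returns (row indices, column indices) of the matches in row-major order, so `[1][0]` is the EOS column of the first row that contains an EOS, which is usually sample 0.

```python
import torch

eos_id = 2  # hypothetical EOS token id, for illustration only

# Toy batch of token ids (batch_size=3, seq_len=6); EOS sits at a
# different position in each sample.
tensor = torch.tensor([
    [5, 7, 2, 0, 0, 0],  # EOS at position 2
    [5, 7, 8, 9, 2, 0],  # EOS at position 4
    [5, 2, 0, 0, 0, 0],  # EOS at position 1
])

# Single-argument torch.where returns the (row, column) indices of all
# True entries, ordered row by row.
rows, cols = torch.where(tensor == eos_id)
print(rows.tolist())  # [0, 1, 2]
print(cols.tolist())  # [2, 4, 1]

# The expression from the issue keeps only the first column index,
# i.e. the EOS position of sample 0; samples 1 and 2 are ignored.
eos_idx = torch.where(tensor == eos_id)[1][0].item()
print(eos_idx)  # 2
```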

The input masks for text are computed from the position of the first EOS token in the batch only (in practice, the EOS of the first sample), but are then applied to every sample in the batch. Is this intentional? The examples mostly use single-sample generation (batch size 1), so this may have fallen through the cracks. If it is intentional, I'd be curious about the reasoning; otherwise I'm happy to make a PR.
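A per-sample alternative could look roughly like the sketch below. It finds the first EOS in each row and builds a boolean mask that is True strictly after that row's EOS. The helper names `first_eos_positions` and `after_eos_mask` are hypothetical, and the sketch does not assume ml-4m's mask conventions (whether True means "masked out", or whether the EOS itself should stay visible); those would need to be matched to how `input_mask` is used in the generation code.

```python
import torch


def first_eos_positions(tokens: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Hypothetical helper: position of the first EOS in each row of a
    (batch, seq_len) tensor. Rows with no EOS get seq_len as a sentinel."""
    batch_size, seq_len = tokens.shape
    is_eos = tokens == eos_id
    positions = torch.arange(seq_len, device=tokens.device).expand(batch_size, seq_len)
    # Keep the index where the token is EOS, otherwise use seq_len,
    # then take the smallest index per row.
    candidates = torch.where(is_eos, positions, torch.full_like(positions, seq_len))
    return candidates.min(dim=1).values


def after_eos_mask(tokens: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Hypothetical helper: boolean (batch, seq_len) mask that is True at
    positions strictly after each row's own first EOS."""
    eos_pos = first_eos_positions(tokens, eos_id)
    positions = torch.arange(tokens.shape[1], device=tokens.device)
    # Broadcast (1, seq_len) against (batch, 1) so each row uses its own EOS.
    return positions.unsqueeze(0) > eos_pos.unsqueeze(1)


eos_id = 2  # hypothetical EOS token id
tokens = torch.tensor([
    [5, 7, 2, 0, 0, 0],  # EOS at 2
    [5, 7, 8, 9, 2, 0],  # EOS at 4
    [5, 2, 0, 0, 0, 0],  # EOS at 1
    [5, 7, 8, 9, 6, 3],  # no EOS
])
eos_pos = first_eos_positions(tokens, eos_id)
mask = after_eos_mask(tokens, eos_id)
print(eos_pos.tolist())  # [2, 4, 1, 6]
print(mask.int().tolist())
```

Using `seq_len` as the sentinel means a row without an EOS ends up with an all-False mask, so nothing is cut off for samples that never emitted an EOS.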

Great stuff btw, thanks for open sourcing this!