deep-spin/sparse_text_generation

Index -1 is out of bounds

Closed this issue · 1 comment

Hi! I am training a language model similar to the one in the Sparse Text Generation project, but with a custom input format. When I start training, it fails to compute the entmax loss.
My inputs and labels both have shape (batch_size, seq_len) before being passed to the loss; inside the loss they become (batch_size*seq_len, vocab_size) and (batch_size*seq_len,) respectively. I mask positions with -1 in the labels, and even though I set ignore_index=-1, my log is (a minimal repro follows the traceback):

```
Traceback (most recent call last):
  File "run_lm_finetuning.py", line 782, in <module>
    main()
  File "run_lm_finetuning.py", line 736, in main
    global_step, tr_loss = train(args, train_dataset, model, tokenizer, gen_func)
  File "run_lm_finetuning.py", line 300, in train
    outputs = model(inputs, labels=labels)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 880, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/app/src/pytorch_transformers/modeling_gpt2.py", line 607, in forward
    loss = self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 880, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/entmax/losses.py", line 17, in forward
    loss = self.loss(X, target)
  File "/usr/local/lib/python3.6/dist-packages/entmax/losses.py", line 278, in loss
    return entmax_bisect_loss(X, target, self.alpha, self.n_iter)
  File "/usr/local/lib/python3.6/dist-packages/entmax/losses.py", line 242, in entmax_bisect_loss
    return EntmaxBisectLossFunction.apply(X, target, alpha, n_iter)
  File "/usr/local/lib/python3.6/dist-packages/entmax/losses.py", line 129, in forward
    ctx, X, target, alpha, proj_args=dict(n_iter=n_iter)
  File "/usr/local/lib/python3.6/dist-packages/entmax/losses.py", line 45, in forward
    p_star.scatter_add_(1, target.unsqueeze(1), torch.full_like(p_star, -1))
RuntimeError: index -1 is out of bounds for dimension 1 with size 50257
```

How can I fix this?
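For reference, here is a minimal sketch that reproduces the error outside the training script. The tensor sizes are arbitrary stand-ins, and I am assuming EntmaxBisectLoss accepts alpha and ignore_index keywords as in the entmax package:

```python
import torch
from entmax.losses import EntmaxBisectLoss

vocab_size = 50257                           # GPT-2 vocabulary size from the traceback
logits = torch.randn(8, vocab_size)          # stands in for (batch_size*seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (8,))  # stands in for (batch_size*seq_len,)
labels[-1] = -1                              # one masked position, as in my labels

loss_fn = EntmaxBisectLoss(alpha=1.5, ignore_index=-1)
loss = loss_fn(logits, labels)  # RuntimeError: index -1 is out of bounds for dimension 1
```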

UPD:
I realized that the problem is not connected to ignore_index, but to a shape mismatch between target and p_star in the forward method of the _GenericLossFunction class. I still don't know how to fix this bug, so please help me if anybody knows how :)
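For anyone hitting the same error: judging from the traceback, the pointwise loss (losses.py line 17, `loss = self.loss(X, target)`) is computed before any ignore_index handling, so the -1 labels reach scatter_add_ directly. One possible workaround, sketched below (the helper name masked_entmax_loss is mine, not part of entmax), is to drop the masked positions before calling the loss:

```python
import torch
from entmax.losses import EntmaxBisectLoss

def masked_entmax_loss(logits, labels, ignore_index=-1, alpha=1.5):
    # logits: (batch_size*seq_len, vocab_size), labels: (batch_size*seq_len,)
    keep = labels != ignore_index  # boolean mask of non-padding positions
    # Only real targets are passed on, so no -1 ever reaches scatter_add_;
    # the default reduction then averages over the kept positions.
    # (Assumes at least one non-masked token per batch.)
    loss_fn = EntmaxBisectLoss(alpha=alpha)
    return loss_fn(logits[keep], labels[keep])
```

Remapping the -1 labels to any valid dummy index and zeroing out those positions' losses yourself would also work, but filtering is simpler.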

Hello, I would like to know about the data preprocessing method you used. Could you let me know? @liehtman