nikitakit/self-attentive-parser

RuntimeError in training: Sizes of tensors must match except in dimension 1

thorunna opened this issue · 0 comments

When I run `python3 src/main.py train --model-path-base models/ --train-path data/gamalt_icepahc/train_random.clean --dev-path data/gamalt_icepahc/dev_random.clean --predict-tags --epochs 20 --use-words`, the following error occurs:

```
Traceback (most recent call last):
  File "src/main.py", line 613, in <module>
    main()
  File "src/main.py", line 609, in main
    args.callback(args)
  File "src/main.py", line 565, in <lambda>
    subparser.set_defaults(callback=lambda args: run_train(args, hparams))
  File "src/main.py", line 313, in run_train
    _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)
  File "/users/home/tha86/berk/self-attentive-parser-copy/src/parse_nk.py", line 1028, in parse_batch
    annotations, _ = self.encoder(emb_idxs, batch_idxs, extra_content_annotations=extra_content_annotations)
  File "/opt/share/python/3.6.1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/users/home/tha86/berk/self-attentive-parser-copy/src/parse_nk.py", line 612, in forward
    res, timing_signal, batch_idxs = emb(xs, batch_idxs, extra_content_annotations=extra_content_annotations)
  File "/opt/share/python/3.6.1/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/users/home/tha86/berk/self-attentive-parser-copy/src/parse_nk.py", line 505, in forward
    annotations = torch.cat([content_annotations, timing_signal], 1)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 627 and 645 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:83
```

I am using a Mac with Python 3.6, PyTorch 1.0.0, and Cython 0.29.14. Training succeeds when the number of input sentences is small, but the error appears quickly once the number is increased.
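For context, the failing call is `torch.cat([content_annotations, timing_signal], 1)`, which requires its inputs to agree in every dimension *except* the concatenation dimension; here they differ in dimension 0 (627 vs. 645 rows). Below is a minimal sketch of that constraint using NumPy, whose `np.concatenate` enforces the same shape rule as `torch.cat`; the feature width of 512 is an assumption for illustration, only the row counts 627/645 come from the error message:

```python
import numpy as np

# torch.cat (like np.concatenate) requires all inputs to match in every
# dimension except the one being concatenated. The traceback concatenates
# along dimension 1, but the inputs disagree in dimension 0.
content_annotations = np.zeros((627, 512))  # e.g. per-token content embeddings
timing_signal = np.zeros((645, 512))        # e.g. per-token position encodings

try:
    # Fails: 627 != 645 in dimension 0, matching the RuntimeError above.
    np.concatenate([content_annotations, timing_signal], axis=1)
except ValueError as e:
    print("shape mismatch:", e)

# With matching row counts, concatenation along axis 1 succeeds:
ok = np.concatenate([np.zeros((627, 512)), np.zeros((627, 512))], axis=1)
print(ok.shape)  # (627, 1024)
```

Since both tensors are supposed to have one row per token in the packed batch, a mismatch like this suggests the two annotation tensors were built from differently sized views of the batch.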

Does anyone know why this error occurs and how it can be fixed?