unable to run GTS on a custom dataset
indranilByjus opened this issue · 4 comments
The module seems to run fine with the provided datasets, but it throws an error when I include a custom dataset.
Error trace:
```
MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py, line 220, in train_tree
    current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0)
IndexError: index 124 is out of bounds for dimension 1 with size 118
```
I'm not sure of the specific reason, but you could check the values of the variables below.
- `dataset.copy_nums`
- `dataset.num_start`
- `dataset.generate_list`
- `dataset.out_idx2symbol`
They all describe the decoder's vocabulary for GTS (and other models). Please check whether they are correct.
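As a minimal sketch (assuming `dataset` is the built MWPToolkit dataset object; the attribute names are those listed above), you could print these values before training:

```python
# Hypothetical debugging snippet: dump the decoder-vocabulary state of a
# built MWPToolkit dataset object before training starts.
print("copy_nums:", dataset.copy_nums)            # max count of numbers copied from a question
print("num_start:", dataset.num_start)            # index where number symbols begin in the output vocab
print("generate_list:", dataset.generate_list)    # constants the decoder can generate (static size)
print("out_idx2symbol:", dataset.out_idx2symbol)  # full output vocabulary

# The tail of the output vocab (from num_start on) should line up with
# generate_list plus the NUM_x copy symbols; a mismatch here is a likely
# cause of an IndexError like the one above.
print("symbols after num_start:", dataset.out_idx2symbol[dataset.num_start:])
print("expected:", len(dataset.generate_list) + dataset.copy_nums)
```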
Second, pay attention to the inputs (`batch_data`) of the model.
At the line that throws the error, `current_nums_embeddings` holds the embeddings of all numbers at the current decoding step (generated numbers + copied numbers). Its size should be `[batch_size, 118, hidden_size]`, where 118 is the sum of the generate size (static across batches) and the copy size (dynamic across batches; it is up to `max(batch_data["num size"])`). You could check whether `batch_data["num size"]` is correct.
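A quick sanity check along those lines (assuming `batch_data` is the dict fed to the model and `dataset` is as above; the key and attribute names follow this comment, everything else is illustrative):

```python
# Hedged check on a single batch: "num size" is the per-sample count of
# numbers copied from the question; the copy dimension of
# current_nums_embeddings is built from its batch maximum.
num_sizes = batch_data["num size"]
copy_size = max(num_sizes)
generate_size = len(dataset.generate_list)   # static across batches

print("generate", generate_size, "+ copy", copy_size, "=", generate_size + copy_size)
# In the failing step this sum was 118, while i - num_start reached 124,
# i.e. some target index points past the available number embeddings.
assert all(n <= dataset.copy_nums for n in num_sizes), "num size exceeds copy_nums"
```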
Another point is `batch_data["num stack"]`. In GTS, if a number appears twice or more in the question sentence (one number has two positions), it has two candidate symbols it could be generated as, and which one to choose is decided while decoding. So the target token is replaced by `UNK_token`, and when decoding, the candidate symbol with the maximal score is chosen as the target symbol. `batch_data["num stack"]` holds the candidate symbols for each `UNK_token`. If `UNK_token` is not replaced by a candidate symbol correctly, it may cause the index to go out of bounds. So please check whether `batch_data["num stack"]` is correct.
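For intuition, here is an illustrative (made-up) example of what a correct "num stack" entry looks like when a number repeats:

```python
# Illustrative example (not actual MWPToolkit output). Suppose the question
# contains the number "5" twice, so it sits at positions 0 and 2 of num_list:
num_list = ["5", "12", "5"]
# The equation token "5" cannot be mapped to a unique NUM_x mask, so the
# target holds UNK_token and the num stack records both candidate positions:
num_stack_entry = [0, 2]
# While decoding, the candidate with the higher score at that step becomes
# the target symbol (see generate_tree_input below). If such an entry is
# missing or wrong, UNK_token is never resolved and the index can run
# out of bounds.
```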
Code for building the number stack:
MWPToolkit/mwptoolkit/data/dataset/abstactdataset.py line 192
```python
def _build_num_stack(self, equation, num_list):
    num_stack = []
    for word in equation:
        temp_num = []
        flag_not = True
        if word not in self.dataset.out_idx2symbol:
            # out-of-vocabulary token: collect its candidate positions
            flag_not = False
            if "NUM" in word:
                # a masked number token like "NUM_3" carries its position directly
                temp_num.append(int(word[4:]))
            for i, j in enumerate(num_list):
                # a raw number string may occur at several positions in the question
                if j == word:
                    temp_num.append(i)
        if not flag_not and len(temp_num) != 0:
            num_stack.append(temp_num)
        if not flag_not and len(temp_num) == 0:
            # unknown number: every position in num_list is a candidate
            num_stack.append([_ for _ in range(len(num_list))])
    num_stack.reverse()
    return num_stack
```
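A hedged walkthrough of that function with made-up inputs:

```python
# Hedged walkthrough with made-up vocabulary and numbers:
#   out_idx2symbol = ["+", "-", "*", "/", "1", "3.14", "NUM_0", "NUM_1", "NUM_2", "<UNK>"]
#   num_list       = ["5", "12", "5"]     # "5" occurs at positions 0 and 2
#   equation       = ["5", "+", "NUM_1"]
#
# "5" is not in out_idx2symbol, so both of its positions are collected:
#   _build_num_stack(equation, num_list)  ->  [[0, 2]]
# "+" and "NUM_1" are in the vocabulary, so they contribute nothing.
# An out-of-vocabulary number that never occurs in num_list falls back to
# every position instead: [0, 1, 2].
```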
Code for choosing the target symbol according to the maximal score:
MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py line 357
```python
def generate_tree_input(self, target, decoder_output, nums_stack_batch, num_start, unk):
    # when the decoder input is a copied num that has two positions, choose the max
    target_input = copy.deepcopy(target)
    for i in range(len(target)):
        if target[i] == unk:
            num_stack = nums_stack_batch[i].pop()
            max_score = -float("1e12")
            for num in num_stack:
                if decoder_output[i, num_start + num] > max_score:
                    target[i] = num + num_start
                    max_score = decoder_output[i, num_start + num]
        if target_input[i] >= num_start:
            target_input[i] = 0
    return torch.LongTensor(target), torch.LongTensor(target_input)
```
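And a hedged usage sketch of that method (assuming `model` is the GTS model instance; all values are illustrative):

```python
import torch

# Illustrative values: one sample whose target token is UNK and whose num
# stack offers positions 0 and 2 (matching the example above); num_start = 4
# means the first four vocab slots are operators.
unk, num_start = 9, 4
target = [unk]
decoder_output = torch.tensor([[0.1, 0.1, 0.1, 0.1, 0.9, 0.2, 0.3, 0.1, 0.1, 0.0]])
nums_stack_batch = [[[0, 2]]]

target_out, target_input = model.generate_tree_input(
    target, decoder_output, nums_stack_batch, num_start, unk)
# Position 0 scores 0.9 vs. 0.3 for position 2, so the UNK target resolves to
# num_start + 0 == 4; target_input masks number tokens to 0 for the decoder:
# target_out == tensor([4]), target_input == tensor([0])
```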
I encountered a similar problem.
I think one possible reason is that the text of the equation/question has a different format, e.g. you are using "x=1+2" as the equation but the data loader is expecting "1+2".
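If that is the cause, one possible fix is to normalize the equations before building the dataset. A minimal sketch, assuming your custom data really does carry an "x=" prefix (adapt as needed):

```python
def normalize_equation(equation: str) -> str:
    """Strip a leading 'x=' so 'x=1+2' matches the expected '1+2' format.

    Illustrative fix only; adapt the prefix to whatever your custom
    dataset actually uses.
    """
    equation = equation.strip()
    if equation.startswith("x="):
        equation = equation[2:]
    return equation

assert normalize_equation("x=1+2") == "1+2"
```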