graykode/nlp-tutorial

The code in 4-1.Seq2Seq might have a wrong section


In the translate function (at line 90), the object 'args' is never defined. Also, make_batch does not expect any arguments, but it is called with '[[word, 'P' * len(word)]], args'.
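
To make the mismatch concrete, here is a tiny stand-alone sketch with a hypothetical zero-argument make_batch (just a stand-in to show the kind of TypeError that call produces, not the tutorial's actual function):

    def make_batch():   # hypothetical stand-in that, like the real one, expects no arguments
        return [], [], []

    try:
        make_batch([['man', 'PPP']], None)   # called with two positional arguments anyway
    except TypeError as e:
        print(e)   # make_batch() takes 0 positional arguments but 2 were given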

So I think the code should be modified.

from

    def translate(word, args):
        input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]], args)

        # make hidden shape [num_layers * num_directions, batch_size, n_hidden]
        hidden = torch.zeros(1, 1, args.n_hidden)
        output = model(input_batch, hidden, output_batch)
        # output : [max_len+1(=6), batch_size(=1), n_class]

        predict = output.data.max(2, keepdim=True)[1] # select n_class dimension
        decoded = [char_arr[i] for i in predict]
        end = decoded.index('E')
        translated = ''.join(decoded[:end])

        return translated.replace('P', '')

to

    # Test
    def translate(word):
        input_batch, output_batch = make_testbatch(word)

        # make hidden shape [num_layers * num_directions, batch_size, n_hidden]
        hidden = torch.zeros(1, 1, n_hidden)
        output = model(input_batch, hidden, output_batch)
        # output : [max_len+1(=6), batch_size(=1), n_class]

        predict = output.data.max(2, keepdim=True)[1] # select n_class dimension
        decoded = [char_arr[i] for i in predict]
        end = decoded.index('E')
        translated = ''.join(decoded[:end])

        return translated.replace('P', '')

and make_testbatch should be declared beforehand:

    # make test batch
    def make_testbatch(input_word):
        input_batch, output_batch = [], []

        input_w = input_word + 'P' * (n_step - len(input_word))
        input = [num_dic[n] for n in input_w]

        # make a sequence with just start token(S) and pad tokens(P)
        output = [num_dic[n] for n in 'S' + 'P' * n_step]

        input_batch = np.eye(n_class)[input]
        output_batch = np.eye(n_class)[output]

        return torch.FloatTensor(input_batch).unsqueeze(0), torch.FloatTensor(output_batch).unsqueeze(0)
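
For reference, here is a minimal stand-alone sketch, assuming the vocabulary setup from 4-1.Seq2Seq (char_arr, num_dic, n_step, n_class); it repeats make_testbatch so the snippet runs by itself and just prints the shapes of the tensors it returns:

    import numpy as np
    import torch

    # vocabulary setup assumed to match 4-1.Seq2Seq:
    # 'S' = start token, 'E' = end token, 'P' = pad token
    char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
    num_dic = {n: i for i, n in enumerate(char_arr)}

    n_step = 5               # assumed max word length used for padding
    n_class = len(num_dic)   # vocabulary size (29 with the alphabet above)

    # same make_testbatch as above, repeated so this snippet is self-contained
    def make_testbatch(input_word):
        # pad the input word with 'P' up to n_step, then one-hot encode it
        input_w = input_word + 'P' * (n_step - len(input_word))
        input = [num_dic[n] for n in input_w]

        # decoder input: a start token followed by pad tokens only
        output = [num_dic[n] for n in 'S' + 'P' * n_step]

        input_batch = np.eye(n_class)[input]
        output_batch = np.eye(n_class)[output]

        # add a batch dimension of 1 in front
        return torch.FloatTensor(input_batch).unsqueeze(0), torch.FloatTensor(output_batch).unsqueeze(0)

    input_batch, output_batch = make_testbatch('man')
    print(input_batch.shape)    # torch.Size([1, 5, 29])  -> [batch_size, n_step, n_class]
    print(output_batch.shape)   # torch.Size([1, 6, 29])  -> [batch_size, n_step + 1, n_class]

The second tensor has the extra time step for the start token, which lines up with the max_len+1(=6) length mentioned in the comment inside translate.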

Thank you