multiclass extension (of notebook 2)
jdmoore7 opened this issue · 1 comment
I'm trying to adapt the code for a multiclass classification task and ran into this error:
TypeError: '<' not supported between instances of 'Example' and 'Example'
Here are the code snippets that changed (assume everything else is the same as notebook 2):
```python
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = len(LABEL.vocab) ### changed
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = RNN(INPUT_DIM,
            EMBEDDING_DIM,
            HIDDEN_DIM,
            OUTPUT_DIM,
            N_LAYERS,
            BIDIRECTIONAL,
            DROPOUT,
            PAD_IDX)

import torch.optim as optim

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss() ### changed

model = model.to(device)
criterion = criterion.to(device)
```
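For reference, here is a minimal sketch of what the two changed lines imply, using made-up tensors instead of the real model: with `OUTPUT_DIM = len(LABEL.vocab)` the model emits one logit per class, `CrossEntropyLoss` takes those logits plus integer class indices, and accuracy has to come from the highest-scoring class. The `categorical_accuracy` helper is hypothetical. It assumes the training loop still uses notebook 2's `binary_accuracy` (sigmoid, then round), which does not carry over to more than one output.

```python
import torch
import torch.nn as nn

# Made-up logits for a batch of 4 examples and 3 classes: shape [batch, OUTPUT_DIM],
# standing in for the model's output when OUTPUT_DIM = len(LABEL.vocab) = 3.
predictions = torch.tensor([[2.0, 0.1, -1.0],
                            [0.0, 3.0, 0.5],
                            [0.2, 0.1, 1.5],
                            [1.0, 2.0, 0.0]])

# Targets are class indices in [0, OUTPUT_DIM) stored as a LongTensor: shape [batch].
labels = torch.tensor([0, 1, 2, 2])

# CrossEntropyLoss applies log-softmax to the logits itself, so the model
# should output raw scores (no softmax layer); the mean loss is a scalar.
criterion = nn.CrossEntropyLoss()
loss = criterion(predictions, labels)


def categorical_accuracy(preds, y):
    """Hypothetical helper: fraction of rows whose top-scoring class equals the target."""
    top_pred = preds.argmax(dim=1)  # predicted class index per example, shape [batch]
    correct = (top_pred == y).float().sum()
    return correct / y.shape[0]


acc = categorical_accuracy(predictions, labels)
# Accuracy is 0.75 here: the last row predicts class 1 but the target is 2.
print(f"loss = {loss.item():.4f}, accuracy = {acc.item():.2f}")
```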
The next block of code will trigger the error:
```python
N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut2-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-888-c1b298b1eeea> in <module>
      7     start_time = time.time()
      8 
----> 9     train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
     10     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
     11 

<ipython-input-885-9a57198441ec> in train(model, iterator, optimizer, criterion)
      6     model.train()
      7 
----> 8     for batch in iterator:
      9 
     10         optimizer.zero_grad()

~/opt/anaconda3/lib/python3.7/site-packages/torchtext/data/iterator.py in __iter__(self)
    140         while True:
    141             self.init_epoch()
--> 142             for idx, minibatch in enumerate(self.batches):
    143                 # fast-forward if loaded from state
    144                 if self._iterations_this_epoch > idx:

~/opt/anaconda3/lib/python3.7/site-packages/torchtext/data/iterator.py in pool(data, batch_size, key, batch_size_fn, random_shuffler, shuffle, sort_within_batch)
    284     for p in batch(data, batch_size * 100, batch_size_fn):
    285         p_batch = batch(sorted(p, key=key), batch_size, batch_size_fn) \
--> 286             if sort_within_batch \
    287             else batch(p, batch_size, batch_size_fn)
    288         if shuffle:

TypeError: '<' not supported between instances of 'Example' and 'Example'
```
Is there a simpler way to structure this code for multiclass classification?
I will check this now. One other thing you need to change is the `LabelField`: remove the `dtype` argument. `CrossEntropyLoss` expects the targets to be `LongTensor`s, whereas the tutorials use `BCEWithLogitsLoss`, which expects the targets to be `FloatTensor`s. The `dtype` argument overrides the TorchText default of making them `LongTensor`s and converts them to `FloatTensor`s instead.
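A small sketch with made-up tensors (plain PyTorch, no TorchText) contrasting the two target formats: `BCEWithLogitsLoss` gets one logit and one float target per example, while `CrossEntropyLoss` gets one logit per class and a `LongTensor` of class indices, which is what a `LabelField` without the `dtype` argument should give you. The tensor values are invented for illustration.

```python
import torch
import torch.nn as nn

# Integer class indices with no dtype given default to 64-bit ints, i.e. a LongTensor.
labels = torch.tensor([0, 2, 1])
print(labels.dtype)  # torch.int64

# Binary setup from the tutorials: one logit per example and float targets
# (the kind of target the dtype argument produces there).
binary_logits = torch.tensor([0.5, -1.0, 2.0])
binary_targets = torch.tensor([1.0, 0.0, 1.0])
bce_loss = nn.BCEWithLogitsLoss()(binary_logits, binary_targets)

# Multiclass setup: one logit per class (shape [batch, n_classes]) and
# LongTensor class indices (shape [batch]) as targets.
multi_logits = torch.tensor([[1.0, 0.2, 0.1],
                             [0.1, 0.3, 2.0],
                             [0.0, 1.5, 0.2]])
ce_loss = nn.CrossEntropyLoss()(multi_logits, labels)

print(f"BCE: {bce_loss.item():.4f}  CE: {ce_loss.item():.4f}")
```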
Have you checked out the multi-class notebook? https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb
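I believe the actual error is that you need to provide a `sort_key` to the `BucketIterator`. Without one, the iterator sorts the raw `Example` objects themselves (the `sorted(p, key=key)` call in the traceback), and `Example` does not define `<`. The datasets provided by TorchText have their `sort_key`s already set, but if you are using your own dataset you will need to provide one manually.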
```python
BATCH_SIZE = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    sort_key = lambda x: len(x.text), # this has been added: bucket and sort by sequence length
    device = device)
```
This is mentioned in the appendix notebook on using TorchText with your own datasets: https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/A%20-%20Using%20TorchText%20with%20Your%20Own%20Datasets.ipynb
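The `TypeError` itself can be reproduced in plain Python. The traceback shows TorchText calling `sorted(p, key=key)`; presumably `key` is `None` when no `sort_key` is supplied, so Python falls back to comparing the `Example` objects directly. The `Example` class below is a hypothetical stand-in for TorchText's, with only `text` and `label` attributes.

```python
class Example:
    """Hypothetical stand-in for TorchText's Example: attributes only, no ordering defined."""

    def __init__(self, text, label):
        self.text = text
        self.label = label


examples = [
    Example(["a", "great", "film"], "pos"),
    Example(["dull"], "neg"),
    Example(["not", "bad"], "neutral"),
]

# Without a key, sorted() compares Example objects with '<', which they don't support.
try:
    sorted(examples)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # '<' not supported between instances of 'Example' and 'Example'

# With a key, sorted() compares the key values instead. Sorting by token count
# puts examples of similar length next to each other.
by_length = sorted(examples, key=lambda ex: len(ex.text))
print([len(ex.text) for ex in by_length])  # [1, 2, 3]

# Sorting by the token list itself also avoids the TypeError, but lists compare
# lexicographically, so the order is alphabetical rather than by length.
by_text = sorted(examples, key=lambda ex: ex.text)
print([len(ex.text) for ex in by_text])  # [3, 1, 2]
```

The length-based key is presumably what you want with `sort_within_batch = True` in notebook 2, since that model packs padded sequences and expects each batch ordered by length.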