What is the dataset for this code?
Hi, could you tell me where I can get the data that the code needs?
Thank you for your help!
Unfortunately, I no longer have the dataset mentioned in this code (`sst.txt`). However, I used a different dataset of 2.9M sentences extracted from Yelp reviews. You can download this dataset at this link. You will likely need to modify the `_load` function in `dataset.py` to account for the new dataset, although it shouldn't take much work since I have already preprocessed it.
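For reference, a loader along these lines could handle both formats. This is a minimal sketch, not the code in `dataset.py`: it assumes the Yelp file holds one preprocessed, whitespace-tokenized sentence per line with no label column, and the class name `SentenceDataset`, the `path` and `labeled` parameters, and the placeholder label `0` are all hypothetical. It keeps the original `seq_len - 2` length limit and skips blank lines, such as the empty string that `split('\n')` produces after a trailing newline.

```python
class SentenceDataset:
    """Hypothetical stand-in for the dataset class in dataset.py."""

    def __init__(self, path, seq_len, labeled=False):
        self.path = path
        self.seq_len = seq_len
        self.labeled = labeled  # True for SST-style "<label> tok1 tok2 ..." lines
        self.reviews, self.labels = self._load()

    def _load(self):
        reviews, labels = [], []
        with open(self.path, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = line.split()
                if not tokens:
                    continue  # skip blank lines, e.g. after a trailing newline
                if self.labeled:
                    # First token is the integer label, the rest is the sentence.
                    label, tokens = int(tokens[0]), tokens[1:]
                else:
                    # Assumed unlabeled format: use a placeholder label.
                    label = 0
                # Same limit as the original check; it presumably reserves
                # two positions for special tokens such as start/end markers.
                if len(tokens) <= self.seq_len - 2:
                    reviews.append(tokens)
                    labels.append(label)
        return reviews, labels
```

With `labeled=False`, only the sentence tokens are counted against the limit, so the `- 1` that compensated for the label token in the SST version is no longer needed.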
Hi, I used the dataset you linked earlier, but I ran into issues after modifying the `_load` function in `dataset.py`, changing it from:
```python
def _load(self):
    with open('sst.txt', 'r') as f:
        sents = [x for x in f.read().split('\n')
                 if len(x.split()) - 1 <= self.seq_len - 2]
    reviews = [x.split()[1:] for x in sents]
    labels = [int(x.split()[0]) for x in sents]
    return (reviews, labels)
```
to:
```python
def _load(self):
    with open('sst.txt', 'r') as f:
        sents = [x for x in f.read().split('\n')
                 if len(x.split()) - 1 <= self.seq_len - 2]
    reviews = [x.split() for x in sents]
    labels = [i for i in range(len(sents))]
    return (reviews, labels)
```
and I ended up getting the error shown in the image below.
Would you mind helping me with this?
I suspect the inputs need to have a `long` dtype. Try changing the line shown in the stack trace to `x = self.embedding(x.long()).permute(1,0,2)`. Also, move the tensors to CUDA with `.cuda()` if you are using a GPU.
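A minimal, self-contained illustration of that fix follows. The `Encoder` class, vocabulary size, and embedding size are hypothetical and not taken from the repository; the point is only that `nn.Embedding` expects integer token indices, so casting with `.long()` before the lookup avoids the dtype error when the ids arrive as floats. It uses `.to(device)` instead of `.cuda()`, which has the same effect when a GPU is available and falls back to the CPU otherwise.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical module illustrating the fix; names and sizes are made up."""

    def __init__(self, vocab_size=100, emb_dim=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)

    def forward(self, x):
        # x: (batch, seq_len) token ids. nn.Embedding needs integer indices,
        # so cast to int64 before the lookup, then reorder the result from
        # (batch, seq_len, emb_dim) to (seq_len, batch, emb_dim).
        return self.embedding(x.long()).permute(1, 0, 2)


# Use the GPU if one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Encoder().to(device)

# Float token ids, as a loader might produce them; without .long()
# in forward(), the embedding lookup would reject this dtype.
ids = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
out = model(ids.to(device))
print(out.shape)  # torch.Size([3, 2, 8])
```

The `.permute(1, 0, 2)` puts the sequence dimension first, which matches the default `(seq_len, batch, features)` layout of PyTorch's recurrent layers when `batch_first=False`.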