shreydesai/latext-gan

what is the dataset of this code?

Closed this issue · 3 comments

Hi, could you tell me where I can get the data that the code needs?
Thank you for your help!

Unfortunately, I don't have the dataset that is mentioned in this code (sst.txt) anymore. However, I used a different dataset with 2.9M sentences extracted from Yelp reviews. You can download this dataset at this link. You will likely need to modify the _load function in dataset.py to account for this new dataset, although it shouldn't be much work since I already preprocessed this dataset.
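For anyone adapting the loader, here is a minimal sketch of what a replacement could look like. It assumes the new file holds one pre-tokenized sentence per line with no leading label, which is a guess about the Yelp file's layout and not something confirmed in this thread. The function name load_unlabeled and the dummy labels are hypothetical, and the seq_len - 2 bound is simply carried over from the original loader.

```python
def load_unlabeled(path, seq_len):
    """Load sentences from a file with one whitespace-tokenized sentence per line.

    Assumption: lines carry no label column, so every token is kept.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f]

    # Keep non-empty sentences of at most seq_len - 2 tokens,
    # the same length bound the original loader applies.
    reviews = [
        line.split()
        for line in lines
        if line and len(line.split()) <= seq_len - 2
    ]

    # The file has no labels, so return a dummy label of 0 per sentence
    # to keep the (reviews, labels) return shape.
    labels = [0] * len(reviews)
    return reviews, labels
```

Inside dataset.py this would become the body of _load, with the path pointing at the downloaded file and seq_len taken from self.seq_len.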

Hi, I used the dataset that you linked earlier, but I ran into an issue after modifying the _load function in dataset.py, changing it from

def _load(self):
        with open('sst.txt','r') as f:
            sents = [x for x in f.read().split('\n') if \
                     len(x.split())-1<=self.seq_len-2]
            reviews = [x.split()[1:] for x in sents]
            labels = [int(x.split()[0]) for x in sents]
        return (reviews, labels)

to

def _load(self):
        with open('sst.txt','r') as f:
            sents = [x for x in f.read().split('\n') if \
                     len(x.split())-1<=self.seq_len-2]
            reviews = [x.split() for x in sents]
            labels = [i for i in range(len(sents))]
        return (reviews,labels)#(reviews, labels)

and I ended up receiving the error shown in the image below.
Capture

Do you mind helping me with this?

I suspect the inputs need to have a long dtype. Try changing the line shown in the stack trace to x = self.embedding(x.long()).permute(1,0,2). Also, move the tensors to the GPU with .cuda() if you are using one.
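As a rough illustration of that suggestion, the sketch below shows an embedding lookup on indices cast to long, followed by the same permute. The screenshot is not reproduced here, so this assumes the error came from the index dtype. The layer sizes and tensor shapes are made up for the example and are not taken from the repository.

```python
import torch
import torch.nn as nn

# Illustrative sizes only: 100-word vocabulary, 8-dimensional embeddings.
embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)

# A batch of 4 sequences of length 5 whose indices ended up stored as floats.
x = torch.randint(0, 100, (4, 5)).float()

# nn.Embedding expects integer indices; a float tensor is rejected.
try:
    embedding(x)
    float_lookup_failed = False
except (RuntimeError, TypeError):
    float_lookup_failed = True

# Casting to long makes the lookup valid; permute(1, 0, 2) then swaps the
# batch and sequence axes: (batch, seq, emb) -> (seq, batch, emb).
out = embedding(x.long()).permute(1, 0, 2)
out_shape = tuple(out.shape)

# Device-agnostic alternative to calling .cuda() directly: the module and
# its inputs must live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding = embedding.to(device)
out_on_device = embedding(x.long().to(device)).permute(1, 0, 2)
```

Using .to(device) instead of .cuda() keeps the same code working on machines without a GPU.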