kimhc6028/soft-decision-tree

Will this work out of the box with [16] shaped input?


Hi, I don't want to classify images, just 16-element tensors per datum. Will this work out of the box with that? If not, where should I look to change it?

Sorry, I don't understand your purpose. What do you mean by "16-element tensors per datum"?

I mean, what is the expected shape of the input data?

This repo uses the default PyTorch MNIST data loader and reshapes the data into (batch_size, -1),
i.e.,

data = data.view(batch_size,-1)
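A quick way to see what that reshape does: a batch of MNIST images (assumed here to arrive as (batch_size, 1, 28, 28), the usual torchvision layout) is flattened to 784 features per sample, while a batch that is already (batch_size, 16) passes through unchanged. The tensors below are dummy zeros used only to illustrate the shapes.

```python
import torch

batch_size = 4

# Dummy MNIST-like batch: (batch_size, channels, height, width)
mnist_like = torch.zeros(batch_size, 1, 28, 28)
print(mnist_like.view(batch_size, -1).shape)  # torch.Size([4, 784])

# Dummy batch of 16-element vectors: already (batch_size, 16)
vectors = torch.zeros(batch_size, 16)
print(vectors.view(batch_size, -1).shape)  # torch.Size([4, 16])
```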

Therefore, I think you can use a custom data loader that outputs data of shape (batch_size, 16). Of course, you may run into some errors while implementing it (this is a 2-year-old repository and I have forgotten most of the details), but there shouldn't be any big problems.
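A minimal sketch of such a loader, using PyTorch's `TensorDataset` and `DataLoader`. The feature and label tensors are random placeholders, and the sample count, class count and `batch_size` are hypothetical values; substitute your own data. Whether the repo expects integer class labels (as assumed here) or another target format isn't stated in this thread. Note also that a model built for flattened MNIST presumably has an input size of 784, so that setting would likely need to become 16 as well; the relevant argument isn't shown here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 1000 samples, 16 features each, 10 classes.
num_samples, num_features, num_classes = 1000, 16, 10
features = torch.randn(num_samples, num_features)
labels = torch.randint(0, num_classes, (num_samples,))

dataset = TensorDataset(features, labels)

# drop_last=True keeps every batch at exactly batch_size rows.
batch_size = 64
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

for data, target in train_loader:
    # Same reshape as in the repo; for (batch_size, 16) input it is a no-op.
    data = data.view(batch_size, -1)
    print(data.shape, target.shape)  # torch.Size([64, 16]) torch.Size([64])
    break
```

`drop_last=True` is a precaution: if the training loop reshapes with a fixed `batch_size` rather than the actual size of each batch, a smaller final batch could fail to reshape or, worse, be silently reshaped into the wrong number of features.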

Thanks Kim, appreciate it.