ViCCo-Group/thingsvision

Model feature extraction not possible on CUDA

florianmahner opened this issue · 1 comment

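If the `torch.device` is `cuda`, `thingsvision.model_class.Model.extract_features()` raises a `TypeError`. The collected features are `torch.Tensor` objects living on the GPU, and NumPy cannot convert CUDA tensors, so `features = np.vstack(features)` (#L223) fails. Suggested fix at #L221: copy each activation to host memory before converting it, i.e. `features.append(act.cpu().numpy())`.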
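A minimal sketch of the suggested fix, assuming activations arrive as a list of 2-D tensors (one per batch). `stack_activations` is a hypothetical helper for illustration, not thingsvision's actual code. It also calls `.detach()`, which goes beyond the suggested fix and only guards against tensors that still carry autograd history. The device falls back to CPU when CUDA is unavailable, so the snippet runs either way.

```python
import numpy as np
import torch


def stack_activations(activations):
    """Stack per-batch activation tensors into a single NumPy array.

    Hypothetical helper: each tensor is detached from the autograd graph and
    copied to host memory before conversion, so it works for CPU and CUDA.
    """
    features = []
    for act in activations:
        # .cpu() is a no-op for CPU tensors; for CUDA tensors it copies to host,
        # which is what .numpy() (and therefore np.vstack) requires.
        features.append(act.detach().cpu().numpy())
    return np.vstack(features)


# Use the GPU if present, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batches = [torch.randn(4, 8, device=device) for _ in range(3)]
features = stack_activations(batches)
print(features.shape)  # (12, 8)
```

Without the `.cpu()` call, passing CUDA tensors straight to `np.vstack` reproduces the `TypeError` described above.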

Feel free to open a PR. :)