yaringal/ConcreteDropout

PyTorch version: small error?

Dear Yarin,

Fascinating research, which I am now trying to use in my own work. I believe there is a small error in the `fit_model` function of the PyTorch version, in the way the mini-batches are computed. It affects execution speed, even if it is probably benign to the results.

I believe this code corrects the error:

```python
...
for i in range(self.nb_epoch):
    for batch in range(int(np.ceil(self.X.shape[0] / self.batch_size))):
        _x = self.X[self.batch_size * batch : self.batch_size * (batch+1)]
        _y = self.Y[self.batch_size * batch : self.batch_size * (batch+1)]
        x = torch.FloatTensor(_x).cuda()  # 32-bit floating point
        y = torch.FloatTensor(_y).cuda()
        mean, log_var, regularization = self.model(x)  # forward pass
        loss = heteroscedastic_loss(y, mean, log_var) + regularization
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
...
```
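To check the slice boundaries of the corrected loop in isolation, here is a minimal pure-Python sketch (no NumPy, PyTorch or GPU needed) of the same batching scheme. `iterate_minibatches` is a hypothetical helper name, not a function from the ConcreteDropout repository, and the sketch assumes the data supports slicing the way a list or NumPy array does.

```python
import math
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iterate_minibatches(data: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive, non-overlapping slices of `data` (hypothetical helper).

    Mirrors the slicing in the corrected fit_model loop: batch number `batch`
    covers indices [batch_size * batch, batch_size * (batch + 1)).
    The last slice is shorter when len(data) is not a multiple of batch_size.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Same batch count as int(np.ceil(X.shape[0] / batch_size)) in the loop above.
    n_batches = math.ceil(len(data) / batch_size)
    for batch in range(n_batches):
        yield data[batch_size * batch : batch_size * (batch + 1)]


if __name__ == "__main__":
    samples = list(range(10))
    batches = list(iterate_minibatches(samples, batch_size=4))
    print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    # Each sample is visited exactly once per epoch, in order.
    assert [x for b in batches for x in b] == samples
```

Because each batch has at most `batch_size` elements, the cost of one epoch stays proportional to the dataset size, and concatenating the batches reproduces the data exactly once.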

Kind regards,

Hans