masa-su/pixyz

Errors in examples/vae

TMats opened this issue · 1 comments

TMats commented

I hit two errors while running the VAE example in examples.
https://github.com/masa-su/Tars_pytorch/blob/master/examples/vae.ipynb

First, some classes in Tars.distributions were renamed, and the example still uses the old names.

  • NormalModel and BernoulliModel were renamed to Normal and Bernoulli

Second, after fixing the first one, I got the error below when running the last cell.
I can't find how to fix this for now.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-21cc15672efd> in <module>()
      7 
      8 for epoch in range(1, epochs + 1):
----> 9     train_loss = train(epoch)
     10     test_loss = test(epoch)
     11 

<ipython-input-7-d073e7c2f1ba> in train(epoch)
      3     for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
      4         data = data.to(device)
----> 5         lower_bound, loss = model.train({"x": data.view(-1, 784)})
      6         train_loss += loss
      7 

~/workspace/tars/Tars_pytorch/Tars/models/vae.py in train(self, train_x, coef)
     34 
     35         self.optimizer.zero_grad()
---> 36         lower_bound, loss = self._elbo(train_x, coef)
     37 
     38         # backprop

~/workspace/tars/Tars_pytorch/Tars/models/vae.py in _elbo(self, x, reg_coef)
     61 
     62         # reconstrunction error
---> 63         samples = self.encoder.sample(x)
     64         log_like = self.decoder.log_likelihood(samples)
     65 

~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in sample(self, x, shape, batch_size, return_all, reparam)
    137         else:  # conditional
    138             x = self._verify_input(x)
--> 139             self._set_distribution(x)
    140 
    141             output = {self.var[0]: self._get_sample(reparam=reparam)}

~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in _set_distribution(self, x)
     50 
     51     def _set_distribution(self, x={}):
---> 52         params = self.get_params(**x)
     53         self.dist = self.DistributionTorch(**params)
     54 

~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in get_params(self, **params)
    113 
    114         # append constant_params to map_dict
--> 115         output.update(self.constant_params)
    116 
    117         return output

AttributeError: 'tuple' object has no attribute 'update'
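
The last frame can be reproduced without Tars. The sketch below is a hypothetical stand-in, not the library's code: the class name `SketchDistribution` is made up, and it assumes `get_params` passes the return value of `forward` straight to `output.update(...)`, which the traceback suggests but does not show.

```python
class SketchDistribution:
    """Hypothetical stand-in for a Tars distribution (not the library's code)."""

    def __init__(self):
        # Assumed to be a dict, since the traceback passes it to update().
        self.constant_params = {}

    def forward(self, x):
        # Assumption: the example's forward returns a tuple of parameters.
        return (x, 1.0)

    def get_params(self, **params):
        output = self.forward(**params)
        # Same call as line 115 in the traceback; tuples have no update().
        output.update(self.constant_params)
        return output


try:
    SketchDistribution().get_params(x=0.0)
except AttributeError as err:
    message = str(err)
    print(message)
```

If this assumption holds, any `forward` that returns a tuple would fail at the same line regardless of the data passed in.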

@TMats
Thank you for reporting these errors!

> I can't find how to fix this for now.

This is because the forward function in Tars.distributions should return a dict object, but in the VAE example it returns a tuple.
I'll fix this example code as soon as possible!
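
Until the example is updated, a possible workaround might look like the sketch below. Judging from the traceback, the value that reaches `output.update(...)` has to be a dict, and it is then unpacked into `self.DistributionTorch(**params)`. The class name and the keys `loc` and `scale` are assumptions here (they match the keyword arguments of `torch.distributions.Normal`); the keys Tars actually expects may differ.

```python
class FixedSketchDistribution:
    """Hypothetical stand-in showing a dict-returning forward (not Tars code)."""

    def __init__(self, constant_params=None):
        self.constant_params = constant_params or {}

    def forward(self, x):
        # Return parameters by name instead of as a tuple.
        # "loc" and "scale" are assumed key names.
        return {"loc": x, "scale": 1.0}

    def get_params(self, **params):
        output = self.forward(**params)
        # Works now because dict has an update() method.
        output.update(self.constant_params)
        return output


params = FixedSketchDistribution().get_params(x=0.5)
print(params)
```

In the real example, the values would be the tensors computed by the encoder and decoder networks rather than plain floats.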

Thanks.