Errors in examples/vae
TMats opened this issue · 1 comment
TMats commented
I ran into two errors while running the VAE example in `examples`.
https://github.com/masa-su/Tars_pytorch/blob/master/examples/vae.ipynb
First, some classes in `Tars.distributions` have been renamed:
`NormalModel`
and `BernoulliModel`
are now `Normal`
and `Bernoulli`, so the notebook's imports no longer work.
Second, after fixing the first issue, I got the error below when running the last cell.
I haven't been able to find a fix for it yet.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-10-21cc15672efd> in <module>()
7
8 for epoch in range(1, epochs + 1):
----> 9 train_loss = train(epoch)
10 test_loss = test(epoch)
11
<ipython-input-7-d073e7c2f1ba> in train(epoch)
3 for batch_idx, (data, _) in enumerate(tqdm(train_loader)):
4 data = data.to(device)
----> 5 lower_bound, loss = model.train({"x": data.view(-1, 784)})
6 train_loss += loss
7
~/workspace/tars/Tars_pytorch/Tars/models/vae.py in train(self, train_x, coef)
34
35 self.optimizer.zero_grad()
---> 36 lower_bound, loss = self._elbo(train_x, coef)
37
38 # backprop
~/workspace/tars/Tars_pytorch/Tars/models/vae.py in _elbo(self, x, reg_coef)
61
62 # reconstrunction error
---> 63 samples = self.encoder.sample(x)
64 log_like = self.decoder.log_likelihood(samples)
65
~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in sample(self, x, shape, batch_size, return_all, reparam)
137 else: # conditional
138 x = self._verify_input(x)
--> 139 self._set_distribution(x)
140
141 output = {self.var[0]: self._get_sample(reparam=reparam)}
~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in _set_distribution(self, x)
50
51 def _set_distribution(self, x={}):
---> 52 params = self.get_params(**x)
53 self.dist = self.DistributionTorch(**params)
54
~/workspace/tars/Tars_pytorch/Tars/distributions/distributions.py in get_params(self, **params)
113
114 # append constant_params to map_dict
--> 115 output.update(self.constant_params)
116
117 return output
AttributeError: 'tuple' object has no attribute 'update'