Encoder is executed 2 times in VAE
tatsuhiko-inoue opened this issue · 2 comments
tatsuhiko-inoue commented
以下のような VAE を実装して実行すると Encoder が2回実行されます。
import torch
import torch.nn as nn
from pixyz.distributions import Normal
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE
import torch.optim as optim
class Encoder(Normal):
def __init__(self):
super().__init__(cond_var=["x"], var=["z"], name="q")
self.linear = nn.Linear(10, 10)
def forward(self, x):
print("Encoder")
return {"loc": self.linear(x), "scale": 1.0}
class Decoder(Normal):
def __init__(self):
super().__init__(cond_var=["z"], var=["x"], name="p")
def forward(self, z):
print("Decoder")
return {"loc": z, "scale": 1.0}
def prior():
return Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"], features_shape=[10], name="p_{prior}")
q = Encoder()
p = Decoder()
prior = prior()
kl = KullbackLeibler(q, prior)
mdl = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
x = torch.zeros((10, 10))
loss = mdl.train({"x": x})
出力
Encoder
Decoder
Encoder
KL divergence と再構成誤差のそれぞれで Encoder を実行しているように見えます。
Encoder を2回実行すると、その分学習時間が長くかかるため、1回で済ませたいのですが、方法はありますでしょうか?
masa-su commented
対応済み(v0.3.3で完全対応)