masa-su/pixyz

Encoder is executed 2 times in VAE

tatsuhiko-inoue opened this issue · 2 comments

以下のような VAE を実装して実行すると Encoder が2回実行されます。

import torch
import torch.nn as nn
from pixyz.distributions import Normal
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE
import torch.optim as optim

class Encoder(Normal):
    def __init__(self):
        super().__init__(cond_var=["x"], var=["z"], name="q")
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        print("Encoder")
        return {"loc": self.linear(x), "scale": 1.0}

class Decoder(Normal):
    def __init__(self):
        super().__init__(cond_var=["z"], var=["x"], name="p")

    def forward(self, z):
        print("Decoder")
        return {"loc": z, "scale": 1.0}

def prior():
    return Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"], features_shape=[10], name="p_{prior}")

q = Encoder()
p = Decoder()

prior = prior()
kl = KullbackLeibler(q, prior)

mdl = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

x = torch.zeros((10, 10))
loss = mdl.train({"x": x})

出力

Encoder
Decoder
Encoder

KL divergence と再構成誤差のそれぞれで Encoder を実行しているように見えます。
Encoder を2回実行すると、その分学習時間が長くかかるため、1回で済ませたいのですが、方法はありますでしょうか?

ありがとうございます.
#109 2回実行される問題については,こちらのプルリクで対応中です.

対応済み(v0.3.3で完全対応)