Differences in batched vs. non-batched FMPE log_prob
Description
When computing log probabilities with FMPE's log_prob method, the resulting values depend on the other elements in the batch. The change I saw was on the order of the third or fourth decimal place.
In any case, thanks a lot already for your work on LAMPE.
Reproduce
Following the example, the two ways of computing log probabilities for a given configuration theta and a batch of corresponding simulated results x produce different results:
from itertools import islice
import torch
import torch.nn as nn
import torch.optim as optim
import zuko
from lampe.data import JointLoader
from lampe.inference import FMPE, FMPELoss
from lampe.utils import GDStep
from tqdm import tqdm
LABELS = [r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"]
LOWER = -torch.ones(3)
UPPER = torch.ones(3)
prior = zuko.distributions.BoxUniform(LOWER, UPPER)
def simulator(theta: torch.Tensor) -> torch.Tensor:
    x = torch.stack(
        [
            theta[..., 0] + theta[..., 1] * theta[..., 2],
            theta[..., 0] * theta[..., 1] + theta[..., 2],
        ],
        dim=-1,
    )
    return x + 0.05 * torch.randn_like(x)
theta = prior.sample()
x = simulator(theta)
loader = JointLoader(prior, simulator, batch_size=256, vectorized=True)
estimator = FMPE(3, 2, hidden_features=[64] * 5, activation=nn.ELU)
loss = FMPELoss(estimator)
optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)
step = GDStep(optimizer, clip=1.0) # gradient descent step with gradient clipping
estimator.train()
with tqdm(range(128), unit="epoch") as tq:
    for epoch in tq:
        losses = torch.stack(
            [
                step(loss(theta, x))
                for theta, x in islice(loader, 256)  # 256 batches per epoch
            ]
        )
        tq.set_postfix(loss=losses.mean().item())
        scheduler.step()
theta_star = prior.sample()
X = torch.stack([simulator(theta_star) for _ in range(10)])
estimator.eval()
with torch.no_grad():
    # e.g. [3.1956, 1.8184, 2.4533, 1.6461, 3.0488, 2.5868, 2.7055, 2.7679, 3.3405, 1.5554]
    log_p_one_batch = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))
    # e.g. [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]
    log_p_individual = [estimator.flow(x).log_prob(theta_star) for x in X]
Expected behavior
I would expect the individual log probability for one theta and x pair not to be affected by the other entries in the X batch.
This is corroborated by the official implementation, which does not show this behaviour when log_prob_batch is evaluated with different subsets of the batch.
In the above example, I would expect both to result in, e.g., [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558].
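The size of the gap between the two example outputs quoted above can be checked directly. This is a minimal sketch in plain Python; `max_abs_gap` is a hypothetical helper name, and the numbers are copied from the comments in the reproduction script.

```python
def max_abs_gap(a, b):
    """Largest absolute elementwise difference between two equally long sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return max(abs(u - v) for u, v in zip(a, b))


# Example values copied from the reproduction script above.
batched = [3.1956, 1.8184, 2.4533, 1.6461, 3.0488, 2.5868, 2.7055, 2.7679, 3.3405, 1.5554]
individual = [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]

gap = max_abs_gap(batched, individual)
print(f"{gap:.4f}")  # prints 0.0033
```

With the tensors from the script, the same helper should accept `log_p_one_batch.tolist()` and `[lp.item() for lp in log_p_individual]`, assuming each individual result is a scalar tensor.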
Causes and solution
I have no clear intuition why that would be the case. I suspected a stochastic influence and that the exact mode of FreeFormJacobianTransform might help, but the difference seems to be deterministic, and setting exact=True did not change it.
I noticed that the LAMPE implementation uses a trigonometric embedding of the time dimension for the vector field computation, whereas the official implementation by the authors does not, but it is not obvious to me that this would explain the difference.
Environment
- LAMPE version: 0.8.2
- PyTorch version: 2.3.0
- Python version: 3.10.13
- OS: Ubuntu 20.04.6 LTS
Hello @LGro,
Thank you for reporting this bug. I think this comes from the tolerances used in FreeFormJacobianTransform, which are way higher in LAMPE/Zuko (1e-5) than in the original implementation (1e-7).
Could you try to modify the atol and rtol in the FreeFormJacobianTransform of log_prob and repeat your experiments?
Also, it might be worth running in double precision (float64).
Thanks for digging into this issue with me.
Indeed, shrinking the tolerances while running with the estimator and inputs at float64 precision does reduce the initially observed discrepancy. Do I understand it right that the discrepancy is not problematic per-se as long as the magnitude is irrelevant for one's application?
> does reduce the initially observed discrepancy
It does not vanish with both absolute and relative tolerances at 1e-7?
> Do I understand it right that the discrepancy is not problematic per-se as long as the magnitude is irrelevant for one's application?
Yes, the discrepancy is not an implementation or method issue, but a numerical one. If it is small enough, it should not affect downstream tasks. It could be worth adding the option to modify the tolerances in the FMPE class though, or maybe a warning in the docstring.
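One plausible reading of "a numerical issue", which this thread does not confirm, is that an adaptive ODE solver picks a single step size for the whole batch from an error estimate taken over all elements, so the steps used for one element depend on its batch mates. The pure-Python sketch below illustrates this with a toy Heun-Euler integrator on dy/dt = -y^3, whose exact solution for y(0) = 1 is 1/sqrt(1 + 2t). `integrate` and `field` are hypothetical names, and this is not the solver Zuko uses.

```python
import math


def integrate(f, y0, t1=1.0, atol=1e-5, rtol=1e-5):
    """Toy adaptive Heun-Euler solver with ONE step size for the whole batch."""
    y, t, h = list(y0), 0.0, 0.1
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        k1 = [f(v) for v in y]
        y_euler = [v + h * k for v, k in zip(y, k1)]  # first-order estimate
        k2 = [f(v) for v in y_euler]
        y_heun = [v + 0.5 * h * (a + b) for v, a, b in zip(y, k1, k2)]  # second-order
        # Scaled error estimate: the max runs over every element of the batch.
        err = max(
            abs(hi - lo) / (atol + rtol * max(abs(v), abs(hi)))
            for v, hi, lo in zip(y, y_heun, y_euler)
        )
        if err <= 1.0:  # accept the step for all elements at once
            t, y = t + h, y_heun
        # Grow or shrink the shared step size based on the batch-wide error.
        h *= min(5.0, max(0.2, 0.9 / math.sqrt(max(err, 1e-12))))
    return y


def field(v):
    return -(v**3)


exact = 1.0 / math.sqrt(3.0)               # y(1) for y(0) = 1
alone = integrate(field, [1.0])[0]         # element integrated on its own
batched = integrate(field, [1.0, 3.0])[0]  # same element, with a batch mate

print(f"alone   - exact: {alone - exact:.2e}")
print(f"batched - exact: {batched - exact:.2e}")
print(f"alone - batched: {alone - batched:.2e}")
```

Both results stay close to the exact value, but they are not identical, because the second element forces smaller steps on the first. Tightening `atol` and `rtol` shrinks the gap in this toy setting, which matches the behaviour reported in the thread.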
> It does not vanish with both absolute and relative tolerances at 1e-7?
For tolerances at 1e-9, the differences go down to the order of 1e-5 or 1e-6, which was enough of an indicator for me. I have not tried to push it to the limit of float64 precision.
How does this compare to the official implementation (at 1e-7)? If, at the same tolerance, the official implementation shows smaller discrepancies between batched and unbatched evaluation, it could be worth investigating further.