Unofficial implementation of Bayesian Flow Networks in easy PyTorch, for didactic purposes.

Bayesian Flow Networks in Easy PyTorch

This repo contains an unofficial implementation of Bayesian Flow Networks (BFNs), as introduced in Graves et al. (2023).
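The operation that gives BFNs their name is a closed-form Bayesian update of the parameters of a simple input distribution. For continuous data, Graves et al. (2023) use a Gaussian input distribution N(mu, 1/rho) and noisy observations y ~ N(x, 1/alpha), for which the update is conjugate. The following is a minimal sketch of that update in plain PyTorch; `bayesian_update` is a hypothetical helper, not part of this repo.

```python
import torch


def bayesian_update(mu: torch.Tensor, rho: float, y: torch.Tensor, alpha: float):
    """Conjugate Gaussian update of the input distribution N(mu, 1/rho).

    `y` is a noisy observation of the data drawn from the sender
    distribution N(x, 1/alpha). Returns the posterior mean and precision.
    """
    rho_new = rho + alpha                      # precisions add
    mu_new = (rho * mu + alpha * y) / rho_new  # precision-weighted mean
    return mu_new, rho_new


torch.manual_seed(0)
x = torch.tensor([0.5, -1.0, 2.0])   # the data, unknown to the receiver
mu, rho = torch.zeros_like(x), 1.0   # standard normal prior
alpha = 4.0                          # accuracy of each noisy observation
for _ in range(50):
    y = x + torch.randn_like(x) / alpha ** 0.5  # sample y ~ N(x, 1/alpha)
    mu, rho = bayesian_update(mu, rho, y, alpha)

print(rho)  # 201.0
print(mu)   # should be close to x
```

Each update adds alpha to the precision, so the belief about the data sharpens steadily as more observations arrive; the network in a BFN is trained to predict the data from these evolving parameters.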

Usage

import torch

from src.unet import UNet
from src.bfn import BayesianFlowNetwork

bfn = BayesianFlowNetwork(
    backbone=UNet(
        net_dim=4,
        ctrl_dim=None,
        use_cond=False,
        use_attn=True,
        num_group=4,
        adapter='b c h w -> b (h w) c',
    ),
    loss_kind='continuous',
    data_kind='continuous',
    data_shape=(32, 32),
)

# Create a batch of fake images for testing
imgs = torch.randn(16, 3, 32, 32)

# Compute the Bayesian Flow loss
loss = bfn.compute_loss(imgs)

# Compute the model gradients
loss.backward()

...

# Once the model is trained, we can draw samples from the
# learned flow simply by running
samples = bfn(
    num_samples=4,
    num_steps=100,
    sigma_1=1e-3,
)
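For intuition about what the loss and sampling calls compute for continuous data, here is a rough, self-contained sketch following the continuous-time loss and sampling procedure described in Graves et al. (2023). It is not this repository's implementation: `ToyEpsNet`, `output_prediction`, `continuous_time_loss` and `sample` are hypothetical names, and we assume the network predicts the noise, the data are scaled to [-1, 1], and `t_min=1e-6`.

```python
import math

import torch
import torch.nn as nn


class ToyEpsNet(nn.Module):
    """Hypothetical stand-in for the backbone: predicts the noise from (mu, t)."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim))

    def forward(self, mu: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([mu, t], dim=-1))


def output_prediction(net, mu, t, gamma, t_min=1e-6, x_range=(-1.0, 1.0)):
    """Turn the predicted noise into a data estimate x_hat, clipped to x_range."""
    eps_hat = net(mu, t)
    safe_gamma = gamma.clamp(min=1e-8)  # avoid 0/0 at t = 0
    x_hat = mu / safe_gamma - torch.sqrt((1 - safe_gamma) / safe_gamma) * eps_hat
    x_hat = x_hat.clamp(*x_range)
    # For very small t the estimate is unreliable, so predict zero instead.
    return torch.where(t < t_min, torch.zeros_like(x_hat), x_hat)


def continuous_time_loss(net, x, sigma_1=1e-3):
    """Monte Carlo estimate of the continuous-time loss for continuous data."""
    t = torch.rand(x.shape[0], 1)                 # t ~ U(0, 1), one per sample
    gamma = 1 - sigma_1 ** (2 * t)
    # Sample mu from the Bayesian flow distribution N(gamma * x, gamma * (1 - gamma)).
    mu = gamma * x + torch.sqrt(gamma * (1 - gamma)) * torch.randn_like(x)
    x_hat = output_prediction(net, mu, t, gamma)
    weight = -math.log(sigma_1) * sigma_1 ** (-2 * t)  # -ln(sigma_1) * sigma_1^(-2t)
    return (weight * (x - x_hat).pow(2).sum(dim=-1, keepdim=True)).mean()


@torch.no_grad()
def sample(net, num_samples, dim, num_steps=100, sigma_1=1e-3):
    """Generate samples by running num_steps Bayesian updates from the prior N(0, I)."""
    mu = torch.zeros(num_samples, dim)
    rho = 1.0
    for i in range(1, num_steps + 1):
        t = torch.full((num_samples, 1), (i - 1) / num_steps)
        x_hat = output_prediction(net, mu, t, 1 - sigma_1 ** (2 * t))
        alpha = sigma_1 ** (-2 * i / num_steps) * (1 - sigma_1 ** (2 / num_steps))
        y = x_hat + torch.randn_like(x_hat) / math.sqrt(alpha)  # y ~ N(x_hat, 1/alpha)
        mu = (rho * mu + alpha * y) / (rho + alpha)              # Bayesian update
        rho += alpha
    t = torch.ones(num_samples, 1)
    return output_prediction(net, mu, t, 1 - sigma_1 ** (2 * t))


torch.manual_seed(0)
net = ToyEpsNet(dim=8)
x = torch.rand(16, 8) * 2 - 1  # fake data in [-1, 1]
loss = continuous_time_loss(net, x)
loss.backward()
samples = sample(net, num_samples=4, dim=8, num_steps=100, sigma_1=1e-3)
print(samples.shape)  # torch.Size([4, 8])
```

The accuracy schedule `alpha` is chosen so that the precision telescopes to exactly `sigma_1 ** -2` after `num_steps` updates, which is why `sigma_1` controls how sharp the final samples are.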

Citations

@article{graves2023bayesian,
  title={Bayesian Flow Networks},
  author={Graves, Alex and Srivastava, Rupesh Kumar and Atkinson, Timothy and Gomez, Faustino},
  journal={arXiv preprint arXiv:2308.07037},
  year={2023}
}