.. toctree::
    :hidden:

    docs/examples/examples
    docs/interface/interface

vbzero is a minimal stochastic variational inference framework for torch with an interface similar to pyro.

Models are declared as Python functions that are decorated with :func:`vbzero.util.model` and use :func:`vbzero.util.sample` statements. For example, the following snippet encodes the standard biased coin model: a uniform Beta(1, 1) prior on the probability of heads and ten Bernoulli coin flips that share it.

>>> import torch as th
>>> from vbzero.util import model, sample
>>> @model
... def biased_coin():
...     proba = sample("proba", th.distributions.Beta(1, 1))
...     x = sample("x", th.distributions.Bernoulli(proba), sample_shape=10)
...     return proba, x
>>> th.manual_seed(1)  # For reproducibility.
<torch...>
>>> biased_coin()
(tensor(0.6003), tensor([1., 1., 0., ...]))
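
Stripped of vbzero, the model is a two-stage generative process: draw ``proba`` from the uniform Beta(1, 1) prior, then flip ten coins that each land heads with probability ``proba``. The sketch below uses only the Python standard library, with ``random.betavariate`` standing in for the torch distributions; the function name ``biased_coin_stdlib`` and its ``n_flips`` and ``rng`` parameters are hypothetical. It records no state and does not reproduce the torch-seeded values above.

```python
import random


def biased_coin_stdlib(n_flips=10, rng=random):
    """Hypothetical plain-Python version of ``biased_coin`` (no vbzero state)."""
    # Beta(1, 1) is the uniform distribution on [0, 1].
    proba = rng.betavariate(1, 1)
    # Each flip is an independent Bernoulli(proba) draw, encoded as 1.0 or 0.0
    # to mirror the float tensor returned by torch's Bernoulli distribution.
    x = [1.0 if rng.random() < proba else 0.0 for _ in range(n_flips)]
    return proba, x


rng = random.Random(1)  # seeded generator for reproducibility
proba, x = biased_coin_stdlib(rng=rng)
print(len(x))  # 10
```

Passing a seeded ``random.Random`` plays the same role as ``th.manual_seed`` in the doctest: two generators with the same seed yield identical draws.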

Sampled values are recorded in a :class:`vbzero.util.State`, which can be provided explicitly as a context manager. For example, we can access all variables as follows.

>>> from vbzero.util import State
>>> th.manual_seed(1)  # For reproducibility.
<torch...>
>>> with State() as state:
...     biased_coin()
(tensor(0.6003), tensor([1., 1., 0., ...]))
>>> state
{'proba': tensor(0.6003), 'x': tensor([1., 1., 0., ...])}

Keeping state in an explicit object allows different datasets and models to be handled within the same process. If no :class:`vbzero.util.State` context is active, a state is created implicitly. It can be retrieved by calling :meth:`vbzero.util.State.get_instance` within the model, but it is discarded after the model invocation unless it was created explicitly as above.
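
One way to picture the implicit-versus-explicit behaviour is a toy re-implementation. The sketch below is hypothetical and is not vbzero's code: ``ToyState``, ``toy_sample``, ``toy_model``, and ``toy_coin`` are invented names, a ``dict`` subclass stands in for :class:`vbzero.util.State`, and Python's ``random`` module stands in for torch. It assumes that a value already present in the active state is reused rather than resampled, which is consistent with the ``LogProb`` doctests that replay ``proba = 0.6003``.

```python
import functools
import random


class ToyState(dict):
    """Hypothetical stand-in for vbzero.util.State: a dict of named values."""

    _active = None  # the state of the innermost active ``with`` block

    def __enter__(self):
        self._previous = ToyState._active  # remember any enclosing state
        ToyState._active = self
        return self

    def __exit__(self, *exc_info):
        ToyState._active = self._previous

    @classmethod
    def get_instance(cls):
        if cls._active is None:
            raise RuntimeError("no active state")
        return cls._active


def toy_sample(name, draw):
    """Return the stored value for ``name``, or draw, record, and return one."""
    state = ToyState.get_instance()
    if name not in state:
        state[name] = draw()
    return state[name]


def toy_model(func):
    """Run ``func`` inside a throwaway state if no state is active."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if ToyState._active is not None:
            return func(*args, **kwargs)
        with ToyState():  # implicit state, discarded after the call
            return func(*args, **kwargs)
    return wrapper


@toy_model
def toy_coin():
    proba = toy_sample("proba", lambda: random.betavariate(1, 1))
    x = toy_sample("x", lambda: [float(random.random() < proba) for _ in range(10)])
    return proba, x


with ToyState() as state:
    first = toy_coin()
print(sorted(state))  # ['proba', 'x']
```

Re-entering the same ``state`` replays the recorded draws, whereas calling ``toy_coin()`` outside any ``with`` block creates a state that is thrown away when the call returns.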

The :class:`vbzero.util.LogProb` context evaluates the log probability of each sampled variable under the model.

>>> from vbzero.util import LogProb
>>> with state, LogProb() as log_prob:
...     biased_coin()
(tensor(0.6003), tensor([1., 1., 0., ...]))
>>> log_prob
{'proba': tensor(0.), 'x': tensor([-0.5103, -0.5103, -0.9171, ...])}
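
The numbers above can be checked by hand. Beta(1, 1) is the uniform distribution on [0, 1], so its log density is zero everywhere, and a Bernoulli(p) outcome contributes log p for heads and log(1 - p) for tails. The sketch below does this with the standard library; ``beta_log_prob`` and ``bernoulli_log_prob`` are hypothetical helpers, and because 0.6003 is the rounded value printed by the doctest, agreement is only to about three decimals.

```python
import math


def beta_log_prob(p, a, b):
    """Log density of Beta(a, b) at p."""
    return (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
            + (a - 1) * math.log(p) + (b - 1) * math.log(1 - p))


def bernoulli_log_prob(x, p):
    """Log probability of a single 0/1 outcome x under Bernoulli(p)."""
    return math.log(p) if x == 1 else math.log(1 - p)


proba = 0.6003  # rounded value printed in the doctest
print(beta_log_prob(proba, 1, 1))  # 0.0 for the uniform prior
print(bernoulli_log_prob(1, proba))  # close to -0.5103 (heads)
print(bernoulli_log_prob(0, proba))  # close to -0.9171 (tails)
```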

Including ``state`` in the ``with`` statement ensures that the previously sampled values are reused, so all variables are defined and their log probabilities can be evaluated. We consider counterfactuals by modifying the state directly or by using :func:`vbzero.util.condition`.

>>> from vbzero.util import condition
>>> conditioned = condition(biased_coin, proba=th.as_tensor(0.5))
>>> with state, LogProb() as log_prob:
...     conditioned()
(tensor(0.5000), tensor([1., 1., 0., ...]))
>>> log_prob
{'proba': tensor(0.), 'x': tensor([-0.6931, -0.6931, -0.6931, ...])}
>>> state
{'proba': tensor(0.5000), 'x': tensor([1., 1., 0., ...])}
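
Conceptually, a counterfactual keeps the recorded coin flips and swaps in a different value of ``proba`` before re-scoring them. The sketch below does this by hand with a plain ``dict`` standing in for the state and a hypothetical ``coin_log_prob`` helper; it is not how :func:`vbzero.util.condition` is implemented, and the flips are only the first three values printed above. With ``proba = 0.5`` every flip has log probability log 0.5, regardless of its outcome.

```python
import math


def coin_log_prob(values):
    """Per-variable log probabilities of the biased coin model for recorded values."""
    p = values["proba"]
    return {
        "proba": 0.0,  # Beta(1, 1) is uniform on [0, 1]
        "x": [math.log(p) if xi == 1.0 else math.log(1.0 - p) for xi in values["x"]],
    }


# Hypothetical recorded values: only the first three flips shown above.
recorded = {"proba": 0.6003, "x": [1.0, 1.0, 0.0]}
counterfactual = dict(recorded, proba=0.5)  # new dict; ``recorded`` is unchanged

print(coin_log_prob(counterfactual)["x"])  # each entry is log(0.5), about -0.6931
```

Building a new dictionary leaves ``recorded`` intact, in contrast to invoking the conditioned model on a shared state.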

.. note::

    The state is modified by invoking the ``conditioned`` model. Use :meth:`vbzero.util.State.copy` to create a shallow copy so that the original state is not modified. In general, we recommend not sharing state across model invocations.
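
A shallow copy duplicates the mapping from names to values but not the values themselves. The sketch below illustrates the difference with a plain ``dict`` and Python lists standing in for a :class:`vbzero.util.State` and its tensors; it assumes that ``State.copy`` has the same shallow semantics as ``dict.copy``. Rebinding a name in the copy is safe, but an in-place edit of a shared value (for a tensor, an in-place method such as ``add_``) is visible through both.

```python
import copy

# Plain dict standing in for a State; lists stand in for tensors.
state = {"proba": [0.6003], "x": [1.0, 1.0, 0.0]}
snapshot = copy.copy(state)  # shallow copy, like dict.copy()

# Rebinding a key in the copy leaves the original untouched ...
snapshot["proba"] = [0.5]
print(state["proba"])  # [0.6003]

# ... but the values themselves are shared, so in-place edits show up in both.
snapshot["x"][0] = 0.0
print(state["x"])  # [0.0, 1.0, 0.0]

# A deep copy duplicates the values too, so edits no longer leak back.
independent = copy.deepcopy(state)
independent["x"][1] = 0.0
print(state["x"])  # [0.0, 1.0, 0.0]
```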