facebookresearch/beanmachine

[Feature Request] Adding Optimization as Inference and extending Optimization to Compositional Inference

jakee417 opened this issue · 4 comments

Issue Description

TL;DR
Support for MLE, MAP, and Variational inference!

Context
In situations where scalability and speed need to be balanced with posterior sample quality, various optimization routines can provide useful approximations for fully Bayesian inference. A good summary of what I am referring to is taken from page 141 of this textbook:
(screenshot of textbook excerpt)

In its simplest form, applying a torch.optim optimizer to the computational graph of a Bean Machine model's log-likelihood should, in theory, reproduce MLE for that model. MAP estimation can likewise be reproduced by using a Delta distribution as the variational distribution in SVI, or again by directly optimizing the log_prob computational graph with the priors included. These kinds of relaxations would be useful if you wanted to "productionize" a static analysis (perhaps initially done with HMC) into a service where latency requirements rule out the more expensive fully Bayesian inference methods.
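
A minimal sketch of that idea in plain torch (not any existing Bean Machine API), assuming a simple Gaussian linear regression with simulated data: MLE is just gradient ascent on the log-likelihood, and adding the log-priors to the same objective turns it into MAP.

import torch
import torch.distributions as dist

torch.manual_seed(0)
X = torch.linspace(-1, 1, 100)
y_obs = 2.0 * X + 1.0 + 0.1 * torch.randn(100)

# Unconstrained parameters to optimize
beta_1 = torch.zeros((), requires_grad=True)
beta_0 = torch.zeros((), requires_grad=True)
log_sigma = torch.zeros((), requires_grad=True)

optimizer = torch.optim.Adam([beta_1, beta_0, log_sigma], lr=0.05)
for _ in range(2000):
    optimizer.zero_grad()
    sigma = log_sigma.exp()
    log_joint = dist.Normal(beta_1 * X + beta_0, sigma).log_prob(y_obs).sum()
    # For MAP, also add the log-priors to the objective, e.g.:
    # log_joint = log_joint + dist.Normal(0., 10.).log_prob(beta_1) + dist.Normal(0., 10.).log_prob(beta_0)
    (-log_joint).backward()
    optimizer.step()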

Steps to Reproduce

N/A

Expected Behavior

Users should be able to take a single functional declaration of a model, e.g.:

import beanmachine.ppl as bm
import torch.distributions as dist

@bm.random_variable
def beta_1():
    return dist.Normal(0, 10)

@bm.random_variable 
def beta_0():
    return dist.Normal(0, 10)

@bm.random_variable
def epsilon():
    return dist.Gamma(1, 1)

@bm.random_variable
def y(X):
    return dist.Normal(beta_1() * X + beta_0(), epsilon())

And apply inference methods such as SingleMaximumLikelihood(), MaximumAPosteriori(), MaximumLikelihoodII(), etc. These added inference methods could then be used with CompositionalInference in the usual way. Since this would be moving from a "samples as a model" to a "parameters as a model" paradigm, I believe there would need to be some exposed properties returning the optimized parameters from these routines.

So we could access something like:

beta_0() # returns tensor of parameters instead of samples
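
A purely hypothetical sketch of how such an API might be invoked, assuming X and y_obs are training tensors (neither the inference classes nor the num_steps argument exist in Bean Machine today; this is only what the request imagines):

map_estimates = bm.MaximumAPosteriori().infer(
    queries=[beta_0(), beta_1(), epsilon()],
    observations={y(X): y_obs},
    num_steps=1000,  # optimizer steps rather than MCMC samples
)
map_estimates[beta_0()]  # optimized point estimate rather than a chain of samples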

System Info

N/A

Additional Context

It seems that comparable packages and this one offer these types of "relaxations" of a fully Bayesian inference approach. Offering this functionality, and also allowing users to mix and match these methods with the existing CompositionalInference framework, would be 🔥🔥🔥

๐Ÿ‘ I think this would be ๐Ÿ”ฅ ๐Ÿ”ฅ ๐Ÿ”ฅ as well. There are some experimental APIs for VI which are much less mature than the references you've provided.

also allowing users to mix and match these methods with the existing CompositionalInference framework

This is an interesting idea, and I think there's some room to nail down how stepping a Markov chain vs. stepping an optimizer for variational parameters should interact. One possibility off the top of my head would be for something like:

CompositionalInference({
    model.foo: bm.SingleSiteAncestralMetropolisHastings(),
    model.bar: bm.SVI(),
})

to mean that each "iteration" consists of one of two possibilities (a rough sketch follows the list):

  1. a Gibbs Metropolis-Hastings step for foo, using the target density p(foo | bar = sample from current q(bar) approximation), or
  2. an SVI step for bar, with the ELBO stochastically estimated while holding foo fixed at its current value in the world.
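
Roughly, the alternation could look like the loop below (mh_step, estimate_elbo, log_joint, the world dict, and q_bar are all placeholder names for this sketch, not Bean Machine internals):

for it in range(num_iterations):
    if it % 2 == 0:
        # (1) Gibbs MH step for foo, conditioning on a draw from the current q(bar)
        world[bar] = q_bar.sample()
        world[foo] = mh_step(target_log_prob=lambda v: log_joint(foo=v, bar=world[bar]),
                             current=world[foo])
    else:
        # (2) SVI step for bar's variational parameters, holding foo at its current world value
        elbo = estimate_elbo(q_bar, log_joint, foo=world[foo])
        optimizer.zero_grad()
        (-elbo).backward()
        optimizer.step()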

Is this close to what you had in mind?

Yes, I think this is the most straightforward way to accomplish this, and best for a first iteration. Additionally, maybe we could pass a guide param directly to bm.SVI(), which could replicate MLE or MAP behavior. With my original example, I could do:

CompositionalInference({
    beta_1: bm.SingleSiteAncestralMetropolisHastings(),
    beta_0: bm.SVI(guide=dist.Delta()),  # treat this as a MAP estimate
})

This would be useful in the sense that I may not need uncertainty for beta_0 and could pass a None or Delta guide, which would do MLE or MAP estimation respectively (similar to this Pyro tutorial, but I understand Bean Machine's VI is still under active development, so maybe this is a really big ask).

Perhaps a more interesting case would be the HMM tutorial. In this case, rather than assume K is fixed, I would like to "maximize" this variable out, doing automated model selection:

class HiddenMarkovModel:
    def __init__(
        self,
        N: int,
        # K: int,
        concentration: float,
        mu_loc: float,
        mu_scale: float,
        sigma_shape: float,
        sigma_rate: float,
    ) -> None:
    ...
    @bm.random_variable
    def lam(self):  # "lambda" is a reserved word in Python, so renamed here
        return dist.Exponential(1.)

    @bm.random_variable
    def K_minus_one(self):
        return dist.Poisson(self.lam())

    @bm.functional
    def K(self):
        # Not sure if this would actually work, but some dist controlling K
        return self.K_minus_one() + 1.

    @bm.random_variable
    def Theta(self, k):
        K = int(self.K())  # cast the sampled K to an int to size the Dirichlet
        return dist.Dirichlet(torch.ones(K) * self.concentration / K)
    ...

In which case I would love to be able to call:

compositional = bm.CompositionalInference({
    model.K: bm.SVI(guide=None),  # just maximize the likelihood of K (Type II MLE / empirical Bayes)
    model.X: bm.SVI(guide=dist.Delta()),  # MAP estimate
    (model.Sigma, model.Mu): ...
})

And as you suggest, this would invoke a coordinate-wise sampling/optimization mixture that would first optimize over my number of components K and then run MCMC updates for the descendants of K. This would even be a better solution for this TFP example!

And one additional API design thought: you can change the effective ratio of optimization to sampling steps just by increasing the learning rate, but maybe we could instead say explicitly, take 10 MCMC samples for every one VI step. I have seen this here, although I believe that was for using MCMC to estimate the ELBO itself.
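
A tiny sketch of what such an explicit schedule could look like, continuing the placeholder helpers from the loop above (mh_step and svi_step are not an existing API):

mcmc_per_vi = 10
for it in range(num_iterations):
    if it % (mcmc_per_vi + 1) < mcmc_per_vi:
        mh_step(world)          # 10 MCMC updates...
    else:
        svi_step(q, optimizer)  # ...for every one VI step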

I took some time to play with the ideas around MLE/MAP as a special case of VI with Deltas and Flats and came up with https://github.com/feynmanliang/beanmachine/blob/7779b6325188cec91395d648128d6cbba60756b7/tutorials/VI_Hierarchical_Regression.ipynb. It's very much a prototype, so please let me know what you think! Some notes:

  • The LinearRegressionNnet section doesn't have an intercept term, so it's compared against the (no-intercept) OLS estimate rather than the true generative model
  • We fail infer() calls whenever the queries kwarg is empty, but as these examples show, sometimes you just have a likelihood term (i.e. an observation) and would still like to optimize free parameters in your model, so we should support this without nonsense like guiding an RV with itself (i.e. rv_to_guides: { rv(): rv() })
  • The world API is not well documented, but there are quite a few useful things (e.g. get_variable to obtain priors and variational distributions using the optimized values without copy-pasting code, get_param to get optimized parameter values). I have it as a return type for ease of development, but we had discussed something similar to TFP's trace function for exposing useful information in the return type. I don't have a strong opinion right now on what to do here...

The prototype doesn't include anything about MixedHMC yet, but I think @zaxtax and I will be discussing this in the next week or two.

I think your tutorial solidly covers the first two use cases in my original post 🥳. Since BM natively covers full Bayes, I believe only the MixedHMC case is left 😎. One question about LinearRegressionNNet in the tutorial: if a ReLU followed the loc parameter, could this actually serve as a building block for a Bayesian neural network? In the sense that you could compose many of these dense layers and then run variational inference on the weights of each layer? Similar to this paper?
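
For concreteness, a rough sketch of the kind of stacking I mean, reusing the bm / dist / torch imports from above (the class name, layer sizes, and fixed observation noise are made up here; it just puts matrix-valued Normal priors on each layer's weights with a ReLU in between):

class TwoLayerBNN:
    def __init__(self, in_dim: int, hidden: int) -> None:
        self.in_dim = in_dim
        self.hidden = hidden

    @bm.random_variable
    def w1(self):
        return dist.Normal(torch.zeros(self.hidden, self.in_dim), 1.0)

    @bm.random_variable
    def w2(self):
        return dist.Normal(torch.zeros(1, self.hidden), 1.0)

    @bm.random_variable
    def y(self, X):
        h = torch.relu(X @ self.w1().T)      # ReLU following the first layer's loc
        loc = (h @ self.w2().T).squeeze(-1)
        return dist.Normal(loc, 0.1)         # fixed observation noise for simplicity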

Some thoughts:

  • For MaximumLikelihoodInfer, would it be possible to create a class that automatically adds the likelihood to world.log_prob and optimizes this directly? Not sure what you had in mind, but I agree that extra reliance on VariationalInfer seems a bit confusing for new users, especially with the rv_to_guides: { rv(): rv() } pathology. While the Flat() and Delta() combination is illustrative of what is going on in MLE, some less-invested users may just wind up confused.
  • I think the TFP-style trace_fn arg would be great for tracking things like loss and parameter values. This becomes doubly important in deployment scenarios where you are monitoring retraining through something like a dashboard or TensorBoard. What really comes to mind is TensorFlow's History callback returned by the fit() method. Users could then pass additional callbacks as needed during object instantiation, maybe?
  • bm.param is new to me, and I think there is room to somehow overload the observation distribution (in our case, model.y(X) or net.forward(X)) to not only return an RVIdentifier but also a torch.Distribution object when possible. You showed a very easy way to extract parameters using get_variable and get_param, but to make predictions I think you would need to redefine the observation distribution externally and then pass in these parameters. If you could somehow tell the bm.random_variable wrapper that you now wanted predictions with the estimated values instead of the RVIdentifier, I think a lot of people would have a much easier time computing log_probs, predictions, and the like. What would be really sweet is a bm.model class decorator that simply followed the sklearn base estimator API, so that once inference was completed, you could use a predict API like with most machine learning libraries. Something like this:
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
# Flat() here is the improper flat prior from the VI prototype notebook

@bm.model  # maybe defaults a __call__() or predict() method that simply returns the "guts" of self.y(X)
class LinearRegressionFreq:

    @bm.random_variable
    def sigma_beta_1(self):
        return Flat()

    @bm.random_variable
    def beta_1(self):
        return Flat()

    @bm.random_variable
    def beta_0(self):
        return Flat()

    @bm.random_variable
    def epsilon(self):
        return Flat()

    @bm.random_variable
    def y(self, X):
        # This behavior is automatically added to the "predict" method but returns torch.distributions instead of RVIdentifier
        return dist.Normal(self.beta_1() * X + self.beta_0(), torch.nn.functional.softplus(self.epsilon()))
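
And a hypothetical usage example of what the decorator could enable, assuming X_new and y_new are held-out tensors (predict() does not exist in Bean Machine today; this is only the imagined behavior):

model = LinearRegressionFreq()
# ... run whichever optimization-based inference was configured above ...
y_dist = model.predict(X_new)  # a torch.distributions.Normal built from the optimized parameter values
y_hat = y_dist.mean            # point predictions
log_likelihood = y_dist.log_prob(y_new).sum()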