I put together this repository as part of a classroom presentation of the paper *Transformers Can Do Bayesian Inference*.
We want to approximate the distribution that underlies some dataset.
The paper introduces the prior-data fitted network (PFN), which works as follows:
Using the above method, we can train a model that directly approximates the term on the left:
This is the posterior predictive distribution (PPD), which is difficult to approximate directly with other methods such as Markov chain Monte Carlo (MCMC).
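To make this concrete, here is a minimal sketch of prior-data fitting (my own illustration, not the authors' code): datasets are repeatedly sampled from a toy prior, split into context pairs and held-out queries, and a small transformer is trained with cross-entropy to predict the held-out labels. Minimizing that loss is what lets the network approximate the PPD. All names and hyperparameters below (`TinyPFN`, `sample_prior_dataset`, the bin count, etc.) are assumptions for illustration, and the PFN attention mask is left out here because it is covered further down.

```python
# A minimal sketch of prior-data fitting, NOT the authors' implementation.
# Assumptions: a toy 1-D linear-regression prior, labels discretized into bins
# so the PPD can be predicted with a softmax, and a tiny transformer encoder
# with no positional embeddings.
import torch
import torch.nn as nn

N_BINS = 32            # discretize y -> classification over bins
D_MODEL = 64
N_CTX, N_QRY = 20, 10  # context (x, y) pairs and held-out query points


def sample_prior_dataset(batch_size):
    """Sample synthetic datasets from a simple prior: y = w*x + b + noise."""
    x = torch.rand(batch_size, N_CTX + N_QRY, 1) * 2 - 1
    w = torch.randn(batch_size, 1, 1)
    b = torch.randn(batch_size, 1, 1)
    y = w * x + b + 0.1 * torch.randn_like(x)
    return x, y


def to_bins(y):
    """Map continuous y values to integer bin indices in [0, N_BINS)."""
    return ((y.clamp(-3, 3) + 3) / 6 * (N_BINS - 1)).long().squeeze(-1)


class TinyPFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_xy = nn.Linear(2, D_MODEL)  # embed known (x, y) pairs
        self.embed_x = nn.Linear(1, D_MODEL)   # embed query x (y unknown)
        layer = nn.TransformerEncoderLayer(D_MODEL, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(D_MODEL, N_BINS)

    def forward(self, x_ctx, y_ctx, x_qry):
        ctx = self.embed_xy(torch.cat([x_ctx, y_ctx], dim=-1))
        qry = self.embed_x(x_qry)
        tokens = torch.cat([ctx, qry], dim=1)  # note: no positional embeddings
        out = self.encoder(tokens)             # full attention; PFN mask shown below
        return self.head(out[:, N_CTX:])       # predict only the query labels


model = TinyPFN()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(100):
    x, y = sample_prior_dataset(batch_size=16)
    x_ctx, y_ctx = x[:, :N_CTX], y[:, :N_CTX]
    x_qry, y_qry = x[:, N_CTX:], y[:, N_CTX:]
    logits = model(x_ctx, y_ctx, x_qry)
    # Cross-entropy on the held-out labels is the prior-data fitting loss;
    # minimizing it trains the network to output the PPD.
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, N_BINS), to_bins(y_qry).reshape(-1)
    )
    opt.zero_grad()
    loss.backward()
    opt.step()
```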
Transformers typically use positional embeddings to model sequences. Here, we want the model to ignore the order of the inputs, i.e. to be "permutation invariant," so the positional embeddings are dropped.
The input points attend to each other, and the queries attend to the input points.
The modifications to Algorithm 4 (single-headed attention) are shown below; they extend easily to multi-head self-attention.
Instead of X and Z, which are vector representations of primary and context tokens,
we have D (data points) and Q (queries).
Input:
| D, input (x, y) pairs
| Q, queries
ℓD ← length(D)
∀𝑡d, 𝑡q: Mask[𝑡d, 𝑡q] ← [[𝑡d ≤ ℓD]]
Then mask the attention scores S as usual:
∀𝑡d, 𝑡q: if ¬Mask[𝑡d, 𝑡q] then S[𝑡d, 𝑡q] ← -∞
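A minimal sketch of that mask in PyTorch (my own code, assuming the usual convention that rows index the attending positions and columns index the positions being attended to):

```python
# A small sketch of the masking scheme above (my own code, not the authors').
# Every position may attend only to the first len_D positions (the data points):
# data points attend to each other, queries attend to the data points,
# and queries never attend to other queries.
import torch

len_D, len_Q = 5, 3              # 5 (x, y) pairs followed by 3 queries
seq_len = len_D + len_Q

# mask[attending position, attended position] is True where attention is allowed
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
mask[:, :len_D] = True           # only the data points can be attended to

# Apply it to the raw attention scores S exactly as in the pseudocode
S = torch.randn(seq_len, seq_len)
S = S.masked_fill(~mask, float("-inf"))
attn = torch.softmax(S, dim=-1)  # each row now puts zero weight on the queries
```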
The authors of the paper created a simple demonstration on Huggingface Spaces: Huggingface Spaces
I ported some of their demonstration code into a Colab that you can run yourself: Colab link
The authors generate a prior dataset for a handwritten-character recognition task called Omniglot. The task looks like this:
The authors trained a model on synthetically generated data that looks like this:
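The paper also fits PFNs to simpler priors, such as Gaussian processes. As a generic illustration of how prior data is generated, here is a rough sketch (my own code, not the authors') of sampling one synthetic dataset from a GP prior with an RBF kernel; each sampled dataset becomes a single training example for the PFN.

```python
# A rough sketch (not the authors' code) of sampling one synthetic dataset from
# a Gaussian-process prior with an RBF kernel. Function names and settings are
# illustrative assumptions.
import numpy as np


def rbf_kernel(x, lengthscale=0.3):
    """Squared-exponential kernel matrix for 1-D inputs."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def sample_gp_dataset(n_points=30, noise=0.05, seed=None):
    """Sample x locations, a function from the GP prior, then noisy labels."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1, 1, size=n_points)
    K = rbf_kernel(x) + 1e-6 * np.eye(n_points)  # jitter for numerical stability
    f = rng.multivariate_normal(np.zeros(n_points), K)
    y = f + noise * rng.normal(size=n_points)
    return x, y


x, y = sample_gp_dataset(seed=0)  # one synthetic dataset drawn from the prior
```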
The paper participated in an open review process.
Some things that are still unclear to me and that I think the paper could address more directly:
- What are some real-world problems that might benefit from this approach? The authors show that the method works for handwriting recognition, and that it is more efficient than existing approaches on a diverse collection of tabular datasets (when run on a Tesla V100...). However, they reduce these datasets to binary classification problems and take other simplifying steps, so it is not clear to me how close we are to a 'market-ready' approach.
- Speaking of the Tesla V100, it seems that the PFN approach should be compared to other GPU-accelerated approaches. I suspect there are other methods that could see massive speedups if they were also GPU-optimized.
- The authors don't provide an intuitive explanation for why the method works so well. In particular, why is the transformer architecture useful here? They do not use positional embeddings, and I do not understand how self-attention contributes to their setup. Would this method work with a simpler NN architecture?
- In terms of the writeup, it isn't clear where this contribution fits into the literature or which issues should be addressed in future work.
- Is it important to have accurate uncertainty estimates in your field? How important are confidence intervals in a predictive model?
- Creating a prior-data fitted network (PFN) requires generating synthetic (labelled) data; Bayesian inference is then performed on real data. Can we think of this approach as Bayesian deep learning with weak supervision?
- If the parameters of a transformer network can serve as a prior for Bayesian inference, could a sufficiently large language model serve as a universal prior?