Work in progress, feedback and collaboration welcome!
Transformers are amazing models. Could they be useful models of the brain? At a fine scale, the "columns" of a transformer somewhat resemble cortical columns. Zooming out, however, the large-scale architecture of the transformer is pretty different from that of the brain:
- In a transformer, the columns are arranged in a sequence. In the brain, they are arranged on a 2D folded cortical sheet.
- A transformer is "narrow" with a typical column count (i.e. sequence length) of ~10³. The brain is "wide" with ~10⁶ columns.
- A transformer has dense all-to-all connectivity between columns. The brain has sparse connectivity with mostly local and some long-range connections.
- All columns in a transformer layer share weights. Weight sharing between columns in the brain is not possible.
- A transformer consists of multiple independent layers (i.e. blocks) applied sequentially. The brain consists of a single layer of columns (the cortex) applied recurrently.
Here we try to design a transformer-like architecture, the Columnformer, that closes these gaps.
- The model consists of a single large layer of "columns".
- Each column consists of an attention and an MLP module, like a transformer.
- Unlike a transformer, columns do not share weights.
- The model is recurrent. The layer is applied to the input recurrently for `depth` steps.
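A minimal NumPy sketch of one recurrent columnformer layer, to make the list above concrete. It is not the code in `model_v1.py`: single-head attention, pre-norm residuals, a parameter-free layer norm, and a ReLU MLP are assumptions made here for brevity, and `init_params`, `column_layer`, and `columnformer` are hypothetical names. The key point is that every weight tensor carries a leading column axis, so no two columns share weights, while the loop over `depth` reuses the same weights at every step.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-5):
    # Parameter-free layer norm over the feature dimension (an assumption for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def init_params(num_columns, dim, hidden, rng):
    # Hypothetical init: every weight tensor has a leading `num_columns` axis,
    # i.e. one untied set of attention + MLP weights per column (fan-in scaled).
    def w(*shape):
        return rng.standard_normal((num_columns, *shape)) / np.sqrt(shape[0])

    return {
        "wq": w(dim, dim), "wk": w(dim, dim), "wv": w(dim, dim), "wo": w(dim, dim),
        "w1": w(dim, hidden), "w2": w(hidden, dim),
    }


def column_layer(x, p):
    # x: (batch, num_columns, dim). "bnd,nde->bne" applies column n's own weights to column n.
    h = layer_norm(x)
    q = np.einsum("bnd,nde->bne", h, p["wq"])
    k = np.einsum("bnd,nde->bne", h, p["wk"])
    v = np.einsum("bnd,nde->bne", h, p["wv"])
    # Attention = communication between columns: attn[b, i, j] is how much column i reads from column j.
    attn = softmax(np.einsum("bid,bjd->bij", q, k) / np.sqrt(q.shape[-1]))
    msg = np.einsum("bij,bjd->bid", attn, v)
    x = x + np.einsum("bnd,nde->bne", msg, p["wo"])
    # MLP = computation within each column, again with untied per-column weights.
    h = np.maximum(np.einsum("bnd,ndh->bnh", layer_norm(x), p["w1"]), 0.0)
    x = x + np.einsum("bnh,nhd->bnd", h, p["w2"])
    return x, attn


def columnformer(x, p, depth):
    # Recurrence: the same layer (same weights) is applied `depth` times.
    attns = []
    for _ in range(depth):
        x, attn = column_layer(x, p)
        attns.append(attn)
    return x, attns


rng = np.random.default_rng(0)
params = init_params(num_columns=16, dim=8, hidden=32, rng=rng)
x = rng.standard_normal((2, 16, 8))
y, attns = columnformer(x, params, depth=3)
print(y.shape, len(attns), attns[0].shape)  # (2, 16, 8) 3 (2, 16, 16)
```

The per-step attention maps are returned so that a connectivity penalty, such as the wiring cost described next, can be computed on them.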
Effectively, the attention module implements communication between columns, while the MLP implements computation within a column (cf. Karpathy on transformers). To promote sparse, structured communication between columns, we also:
- Embed columns in a geometric space, e.g. a 2D grid or a sphere.
- Penalize the "wiring cost" of the attention map with respect to this geometry.
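One way the geometry and wiring cost could be expressed, sketched under assumptions: columns sit on a 2D grid (a sphere would only change `grid_positions`), and the penalty is taken to be the attention-weighted Euclidean distance between columns, i.e. the expected length of the connections each column reads from. The exact form of the penalty (linear vs. squared distance, a different norm) is an open design choice, and all function names here are hypothetical.

```python
import numpy as np


def grid_positions(height, width):
    # Hypothetical geometry: place columns on a height x width 2D grid, row-major order.
    ys, xs = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    return np.stack([ys.ravel(), xs.ravel()], axis=-1).astype(float)


def distance_matrix(pos):
    # Euclidean distance between every pair of columns: dist[i, j] = ||pos_i - pos_j||.
    diff = pos[:, None, :] - pos[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))


def wiring_cost(attn, dist):
    # attn: (batch, num_columns, num_columns), rows sum to 1.
    # Assumed penalty: sum_j attn[i, j] * dist[i, j] (expected connection length),
    # averaged over query columns i and over the batch.
    return float((attn * dist).sum(-1).mean())


pos = grid_positions(4, 4)
dist = distance_matrix(pos)
n = len(pos)
uniform = np.full((1, n, n), 1.0 / n)  # dense all-to-all attention
local = np.eye(n)[None]                # each column attends only to itself
print(wiring_cost(local, dist))                                # 0.0
print(wiring_cost(uniform, dist) > wiring_cost(local, dist))   # True
```

In training, this term would presumably be added to the task loss with some weight, summed over the attention maps from all recurrent steps, so that long-range connections are used only when they pay for themselves.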
The proposed model is highly flexible. Only the geometry of the column layer constrains the learned connectivity pattern (cf. Geometric constraints on brain function). This raises interesting questions:
- What sorts of connectivity patterns does the model learn?
- Will we see spontaneous emergence of functional hierarchy? Feedback connections? Topographic organization? Functional specialization?
- What kinds of geometries and wiring cost penalties lead to more brain-like connectivity?
- Is it possible to learn an optimal geometry?
Both the transformer and the columnformer have a width (the number of columns) and a depth (the number of compute steps). One key difference is that the transformer shares weights across width, whereas the columnformer shares weights across depth (through recurrence). What impact does this have on model performance?
- Because width >> depth, columnformers have many more parameters than transformers, and therefore less inductive bias. Will columnformers even learn?
- How hard is it to get the untied columns in a columnformer to "agree" on a latent feature space? Do we need some penalty term to promote feature consistency (e.g. feature smoothness)?
- Are there any advantages to sharing weights across depth? I.e., is there an advantage to recurrence?
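Back-of-the-envelope arithmetic for the parameter-count question, assuming a standard block layout (four `dim x dim` attention projections plus an MLP with 4x expansion, biases and norms ignored) and that a column in both models has the same width `dim`. The function names are hypothetical. Under these assumptions the parameter ratio is simply width / depth.

```python
def block_params(dim, mlp_ratio=4):
    # Rough per-block (or per-column) count, ignoring biases and norms:
    # attention has 4 dim x dim projections (Q, K, V, output),
    # the MLP has two dim x (mlp_ratio * dim) matrices.
    return 4 * dim * dim + 2 * mlp_ratio * dim * dim


def transformer_params(dim, depth):
    # Weights shared across width (columns/tokens), untied across depth (blocks).
    return depth * block_params(dim)


def columnformer_params(dim, width):
    # Weights untied across width (columns), shared across depth (recurrence).
    return width * block_params(dim)


print(transformer_params(768, depth=12))     # 84934656
print(columnformer_params(768, width=1024))  # 7247757312
```

With ~10³ columns and a transformer-sized column, the columnformer ends up roughly 85x larger. In practice one would likely shrink the per-column width to keep the total budget manageable.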
Brain activity and connectivity patterns are both highly sparse. Likewise, it will be important to leverage sparsity in columnformers as we scale the number of columns. What's the best way to approach this?
- Should we try to hand-design sparse connectivity patterns?
- Can we learn sparse connectivity patterns? What about some kind of progressive model training, where we alternate between training, pruning connections, and scaling the model?
- Will it be useful to promote sparsity over the column activations? Or could activation sparsity emerge spontaneously?
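A sketch of the first two options from the list above, under stated assumptions. `sparse_mask` is one hypothetical hand-designed pattern (local connections within a radius plus a few random long-range ones); `prune_topk` is one hypothetical pruning rule (keep each column's `k` strongest input connections, averaged over a batch). Disallowed connections are removed by setting their scores to -inf before the softmax. A progressive schedule could alternate training with a pruning step like this and feed the resulting mask back into the model, but the rule and schedule here are illustrations, not the project's method.

```python
import numpy as np


def sparse_mask(pos, radius, num_long_range, rng):
    # Hypothetical hand-designed pattern: each column connects to every column within
    # `radius` (including itself) plus `num_long_range` randomly chosen distant columns.
    dist = np.sqrt(((pos[:, None, :] - pos[None, :, :]) ** 2).sum(-1))
    mask = dist <= radius
    for i in range(len(pos)):
        far = np.flatnonzero(~mask[i])
        if far.size > 0:
            picks = rng.choice(far, size=min(num_long_range, far.size), replace=False)
            mask[i, picks] = True
    return mask


def masked_softmax(scores, mask):
    # Disallowed connections get -inf before the softmax, so their weight is exactly 0.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)


def prune_topk(attn, k):
    # Assumed pruning rule: for each column i, keep the k connections j with the
    # largest batch-averaged attn[:, i, j] and drop the rest.
    avg = attn.mean(axis=0)
    keep = np.argsort(avg, axis=-1)[:, -k:]
    mask = np.zeros(avg.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return mask


rng = np.random.default_rng(0)
ys, xs = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
pos = np.stack([ys.ravel(), xs.ravel()], axis=-1).astype(float)
mask = sparse_mask(pos, radius=1.5, num_long_range=2, rng=rng)
print(int(mask.sum()), "of", mask.size, "connections kept")  # 612 of 4096 connections kept

attn = masked_softmax(rng.standard_normal((2, 64, 64)), mask)
pruned = prune_topk(attn, k=4)
```

Note that masking only zeroes the attention weights; to actually save compute at ~10⁶ columns, the scores for masked pairs would have to be skipped entirely, e.g. with a sparse or block-sparse attention kernel.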
- Initial model implementation (`model_v1.py`)
- Masked-image-modeling training implementation
- Initial training run of small model on COCO
This is a personal side research project. All work will be done openly. If you're interested in this idea, please get in touch! Feedback or collaboration is very welcome!
- Lu, Zejin, et al. End-to-end topographic networks as models of cortical map formation and human visual behaviour: moving beyond convolutions. arXiv (2023).
- Doshi, Fenil R., and Talia Konkle. Cortical topographic motifs emerge in a self-organized map of object space. Science Advances (2023).
- Margalit, Eshed, et al. A Unifying Principle for the Functional Organization of Visual Cortex. bioRxiv (2023).
- Achterberg, Jascha, et al. Spatially embedded recurrent neural networks reveal widespread links between structural and functional neuroscience findings. Nature Machine Intelligence (2023).
- Pogodin, Roman, et al. Towards biologically plausible convolutional networks. NeurIPS (2021).