Left branches getting ignored in PyTorch implementation
p-i- opened this issue · 2 comments
Bojan and I dug through this work (convo in Yannic's Discord -> DailyPapers channel).
In UltraFastBERT/benchmark_pytorch/fff/fff_bmm.py:

```python
# y = torch.einsum('b i j , b i -> b j', selected_w2s, F.gelu(all_logits))
y = torch.einsum('b i j , b i -> b j', selected_w2s, all_scores)
return y
```
I've removed the .gelu from this line. (Bojan also switched to einsum to improve clarity, but that's not relevant here.)
If you apply .gelu, you're throwing away most of the information from every negative-scoring node: GELU maps clearly negative scores to (nearly) zero, so their y-vector contributions get multiplied by roughly 0.
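To see how strongly GELU suppresses those nodes, here is a minimal, dependency-free sketch (plain Python rather than torch) of the exact erf-based GELU, s * Phi(s), evaluated at a few example node scores. Strictly speaking GELU does not zero negative inputs, but for clearly negative scores its output is very close to 0.

```python
import math

def gelu(s: float) -> float:
    """Exact (erf-based) GELU: s * Phi(s), Phi being the standard normal CDF."""
    return 0.5 * s * (1.0 + math.erf(s / math.sqrt(2.0)))

# Weight applied to a node's y-vector: the raw score vs. the GELU'd score.
for score in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"score={score:+.1f}  raw={score:+.4f}  gelu={gelu(score):+.4f}")
```

A node scoring -3 keeps only about 0.1% of its raw weight under GELU (roughly -0.004 instead of -3), while a node scoring +3 passes through almost unchanged.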
Here's a gist with an MNIST example: https://gist.github.com/p-i-/784ea313d21856c286b823f27bf79d90
If you put the .gelu back, the accuracy deteriorates.
Notice:
```python
import torch
import torch.nn as nn
# FFF is assumed to be the layer defined in the gist above.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFF(nIn=28*28, nOut=500)
        self.fc2 = FFF(nIn=500, nOut=10)
        # self.fc1 = FFF(nIn=28*28, nOut=10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc1(x)
        return y_hat
```
So I'm introducing a nonlinearity (the ReLU) between the two FFF layers instead of inside them.
I think there may be a cleaner way to conceptualize the result you have arrived at.
Each node has a node.x, which points in some direction in INPUT space, and a node.y, which points in some direction in OUTPUT space.
Each node[][].x represents the normal vector to a region-splitting hyperplane. And for input x, node[p][q].score = dot(x, node[p][q].x) projects the input onto this normal vector.
If it's positive, x is on the "sunny side" of the hyperplane and we branch up and to the right; otherwise it's on the "dark side" and we branch up and to the left.
Either way we'll reach a new node with a fresh region-splitting hyperplane.
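The traversal just described, as a minimal dependency-free sketch (plain Python rather than torch). It assumes the nodes are stored in heap order (children of node i at 2*i + 1 and 2*i + 2) and that a score of exactly 0 counts as sunny-side; both are assumptions for illustration, and the names traverse and node_x are hypothetical.

```python
from typing import List, Sequence, Tuple

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

def traverse(x: Sequence[float], node_x: List[Sequence[float]],
             depth: int) -> Tuple[List[int], List[float]]:
    """Walk a tree of the given depth (solitary root = depth 0) and return
    the visited node indices and their scores. Heap order is assumed:
    node i has children 2*i + 1 (dark side) and 2*i + 2 (sunny side)."""
    assert len(node_x) == 2 ** (depth + 1) - 1, "expected a full binary tree"
    node, path, scores = 0, [], []
    for _ in range(depth + 1):           # one node per level, root included
        score = dot(x, node_x[node])     # project x onto this hyperplane's normal
        path.append(node)
        scores.append(score)
        node = 2 * node + (2 if score >= 0 else 1)  # sunny -> right, dark -> left
    return path, scores

# Depth-1 tree in a 2-D input space: root normal along the first axis.
node_x = [[1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
print(traverse([-2.0, 3.0], node_x, depth=1))  # ([0, 1], [-2.0, 3.0])
```

Here x = [-2, 3] lands on the root's dark side, so it moves to node 1, whose normal points along the second axis.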
And so, once we've traversed the depth-D tree (following the authors, I'm counting a solitary root node as a tree of depth 0, i.e. D=0), we have split the input space into 2**D regions, and we are inside one of them.
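To keep the off-by-one straight under this convention: a depth-D tree has D+1 levels, so each input visits D+1 nodes, but the leaf's own score only weights its y-vector and does not pick a further branch, so the number of distinct paths (regions) is 2**D. A small counting helper (the name tree_sizes is hypothetical):

```python
def tree_sizes(depth: int) -> dict:
    """Counts for a full binary tree whose solitary root has depth 0."""
    return {
        "total_nodes": 2 ** (depth + 1) - 1,  # every node stored in the layer
        "visited_per_input": depth + 1,       # one node per level on the winning path
        "regions": 2 ** depth,                # distinct root-to-leaf paths
    }

print(tree_sizes(0))  # a lone root: 1 node, 1 visited, 1 region
print(tree_sizes(3))  # 15 nodes, 4 visited, 8 regions
```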
And this to me is the beautiful part.
If we consider the "winning" node sequence node_0 ... node_D (one node per level, so D+1 nodes, since the root sits at depth 0), then the node_k.x form a basis e_0 ... e_D for a (D+1)-dimensional subspace within our INPUT space (assuming they are linearly independent).
And the node_k.y form a basis f_0 ... f_D for a (D+1)-dimensional subspace within our OUTPUT space (again assuming linear independence).
And if the e_k are orthonormal, our input x can be written as lambda_0 e_0 + ... + lambda_D e_D + remainderTerm, where lambda_k is just node_k.score (in general, lambda_k = dot(x, e_k) is simply the projection of x onto e_k).
And we're mapping this to lambda_0 f_0 + ... + lambda_D f_D.
So the FFF layer is figuring out a "most-useful" (D+1)-dimensional transform and applying it: a linear map from a basis over INPUT space to a basis over OUTPUT space.
And this basis-pair depends on where our input x is located in INPUT space. There are 2**D possible basis-pairs, one per region.
And the backprop will move the bases around to optimize performance. So it will learn a "most-useful" mapping. It reminds me of LoRA in this way.
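Within one region (a fixed winning path) and without the GELU, the layer is linear in x. A dependency-free sketch with hypothetical names path_output and path_matrix checks that summing lambda_k f_k gives the same result as applying the explicit matrix M = sum_k f_k e_k^T, whose rank is at most the number of nodes on the path. This low-rank, region-dependent matrix is one concrete reading of the LoRA comparison.

```python
from typing import List, Sequence

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

def path_output(x: Sequence[float], es: List[Sequence[float]],
                fs: List[Sequence[float]]) -> List[float]:
    """y = sum_k lambda_k * f_k with lambda_k = dot(x, e_k); no GELU."""
    lambdas = [dot(x, e) for e in es]
    return [sum(lam * f[j] for lam, f in zip(lambdas, fs))
            for j in range(len(fs[0]))]

def path_matrix(es: List[Sequence[float]],
                fs: List[Sequence[float]]) -> List[List[float]]:
    """The same map as an explicit OUTPUT x INPUT matrix M = sum_k f_k e_k^T.
    Its rank is at most len(es), the number of nodes on the path."""
    return [[sum(f[i] * e[j] for e, f in zip(es, fs))
             for j in range(len(es[0]))]
            for i in range(len(fs[0]))]

# A two-node path: 3-D input, 2-D output.
es = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]   # node_k.x along the path
fs = [[2.0, 0.0], [0.0, -1.0]]            # node_k.y along the path
x = [3.0, 4.0, 5.0]
M = path_matrix(es, fs)
print(path_output(x, es, fs))             # [6.0, -9.0]
print([dot(row, x) for row in M])         # [6.0, -9.0]
```

Putting a GELU on the lambdas breaks this identity, since the weights are then no longer linear in x.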
There's room for exploring quite a few ideas from here. I've emailed one of the authors.
Hi,
Thanks for getting in touch and for the summary. Do you have any action items to suggest regarding the above?
Bojan and I are actively exploring this area at https://github.com/sap-ient-ai/FFF
There's a Discord linked in the README.md where a group of us are hanging out and trying to get our heads around this rich new conceptual space that your work has broken open.
The fundamental idea seems to be that an FF layer is expensive: it is an nIn x nOut matrix, and that compute power can surely be put to better use.
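A back-of-the-envelope cost comparison per input, counting only weight multiply-accumulates (no bias, no activation, and no cost for gathering the selected node vectors, which can matter a lot in practice). The depth of 3 is an arbitrary illustrative choice, not the depth used in the gist, and the function names are hypothetical.

```python
def dense_macs(n_in: int, n_out: int) -> int:
    """Multiply-accumulates for one input through a dense nIn x nOut layer."""
    return n_in * n_out

def fff_macs(n_in: int, n_out: int, depth: int) -> int:
    """Multiply-accumulates for one input through an FFF of the given depth
    (root at depth 0): each of the depth + 1 visited nodes costs one dot
    product with its x-vector (n_in) and one scaled add of its y-vector (n_out)."""
    return (depth + 1) * (n_in + n_out)

n_in, n_out, depth = 28 * 28, 500, 3   # depth 3 is an arbitrary example
print(dense_macs(n_in, n_out))         # 392000
print(fff_macs(n_in, n_out, depth))    # 5136
```

The dense cost grows with the product nIn * nOut, while the FFF cost grows with the sum nIn + nOut times the number of levels.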
Equally fundamental seems to be the idea of branching conditionally on the input.
Bojan knocked out an experiment yesterday in which we pre-calculate the lambdas. Each node we traverse gives us one lambda, so how about just passing the input through an nIn x depth matrix to pre-generate all the lambdas at once, then using them to navigate the tree and applying them to the chosen set of y-vectors?
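One way the pre-generated-lambdas idea could look, as a hypothetical reconstruction from the description above (not Bojan's code): each tree level has a single x-vector shared by all nodes on that level, so all lambdas come from one matrix product up front, and the tree is only used to pick which y-vectors they scale. I use one column per level; whether "depth" means the number of levels is an assumption here. Names are hypothetical, and it is plain Python rather than torch.

```python
from typing import List, Sequence

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

def pregen_fff(x: Sequence[float], level_x: List[Sequence[float]],
               node_y: List[Sequence[float]]) -> List[float]:
    """level_x: one x-vector per level (the columns of an nIn x n_levels matrix).
    node_y:  one y-vector per node, heap-ordered, 2**n_levels - 1 of them."""
    n_levels = len(level_x)
    assert len(node_y) == 2 ** n_levels - 1, "expected a full binary tree"
    lambdas = [dot(x, e) for e in level_x]  # all lambdas at once, before the tree
    y = [0.0] * len(node_y[0])
    node = 0
    for lam in lambdas:
        y = [yi + lam * fi for yi, fi in zip(y, node_y[node])]  # scale this node's y-vector
        node = 2 * node + (2 if lam >= 0 else 1)  # the lambda's sign picks the branch
    return y

level_x = [[1.0, 0.0], [0.0, 1.0]]    # two levels, 2-D input
node_y = [[1.0], [10.0], [100.0]]     # three nodes, 1-D output
print(pregen_fff([-1.0, 2.0], level_x, node_y))  # [19.0]
print(pregen_fff([1.0, 2.0], level_x, node_y))   # [201.0]
```

The design choice here is that the x-side cost drops to a single nIn x n_levels matrix product regardless of the path; the tree only routes the y-side.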
Amazingly, this benchmarks about the same as an FF layer on CIFAR10.
But there's the challenge of the information content of the input: just because a novel FFF variant performs well on MNIST/CIFAR doesn't guarantee it will handle a dataset with higher intrinsic dimensionality (or, to put it another way, a more interesting heatmap over the input distribution).
I feel we're bumping into many of the problems that you yourselves encountered in your research.