samuela/git-re-basin

Cost Matrix Computation in Weight Matching

frallebini opened this issue · 10 comments

Hi, I read the paper and I am having a really hard time reconciling the formula

$$P_\ell = \arg\max_{P_\ell} \; \langle W_\ell^A, P_\ell W_\ell^B P_{\ell-1}^\top \rangle_F + \langle W_{\ell+1}^A, P_{\ell+1} W_{\ell+1}^B P_\ell^\top \rangle_F$$

with the actual computation of the cost matrix for the LAP in weight_matching.py, namely

A = jnp.zeros((n, n))
for wk, axis in ps.perm_to_axes[p]:
  w_a = params_a[wk]
  w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
  w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
  A += w_a @ w_b.T

Are you following a different mathematical derivation or am I missing something?

Hi @frallebini! The writeup in the paper is for the special case of an MLP with no bias terms -- the version in the code is just more general. The connection here is that there's a sum over all weight arrays that interact with that P_\ell. Then for each one, we need to apply its relevant permutations on all other axes, take the Frobenius inner product with the reference model, and add all those terms together. So A represents that sum, each for-loop iteration adds a single term to the sum, get_permuted_param applies the other (non-P_\ell) permutations to w_b, and the moveaxis-reshape-matmul corresponds to the Frobenius inner product with w_a.
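
For what it's worth, here is how the accumulated A gets turned into a permutation. This is a minimal sketch rather than the repo's exact code: it assumes a plain NumPy cost matrix and uses scipy.optimize.linear_sum_assignment with maximize=True to solve the LAP, i.e. to maximize $\langle P, A \rangle_F$ over permutation matrices.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def solve_lap(A):
  # Maximize sum_i A[i, perm[i]], i.e. <P, A>_F over permutation matrices P.
  ri, ci = linear_sum_assignment(A, maximize=True)
  assert (ri == np.arange(len(ri))).all()
  return ci  # perm[i] = ci[i]

A = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 1.0],
              [4.0, 0.0, 0.0]])
print(solve_lap(A))  # [2 1 0]
```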

Thanks @samuela, I understand that the code is a generalization of the MLP with no bias case, but still:

  1. If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?
  2. How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

If the moveaxis-reshape-@ operation corresponded to the Frobenius inner product with w_a, wouldn't A be a scalar?

Ack, you're right! I messed up: it's not actually a Frobenius inner product, just a regular matrix product. The moveaxis-reshape combo is necessary to flatten dimensions that we don't care about in the case of non-2d weight arrays.
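
To make the flattening concrete, here is a toy shape trace, assuming a convolutional kernel of shape (3, 3, 16, 32) whose last axis (channel_out) is the one acted on by P_\ell; the sizes are made up for the example:

```python
import jax.numpy as jnp

w = jnp.zeros((3, 3, 16, 32))   # (h, w, channel_in, channel_out)
axis, n = 3, 32                 # the axis that P_\ell acts on

w = jnp.moveaxis(w, axis, 0)    # (32, 3, 3, 16): permuted axis moved to front
w = w.reshape((n, -1))          # (32, 144): every other axis flattened away

print(w.shape)                  # (32, 144)
```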

How does get_permuted_param "skip" the non-P_\ell permutations? Doesn't the except_axis argument mean that, for example, if I want to permute rows, then I have to apply the permutation vector perm[p] along the column dimension?

Yup, that's exactly what except_axis is doing. But I think you may have it backwards -- except_axis skips the P_\ell axis and applies all the other, currently-fixed P's to the remaining axes.
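
For reference, here is a rough sketch of what get_permuted_param does -- a paraphrase rather than a verbatim copy of the repo's helper, and it assumes ps.axes_to_perm maps each parameter name to a tuple with one permutation id (or None) per axis:

```python
import jax.numpy as jnp

def get_permuted_param(ps, perm, k, params, except_axis=None):
  """Return params[k] with every applicable permutation applied, except along
  `except_axis` (the axis acted on by the P we are currently solving for)."""
  w = params[k]
  for axis, p in enumerate(ps.axes_to_perm[k]):
    if axis == except_axis:
      continue              # leave the P_\ell axis untouched
    if p is not None:       # None means no permutation acts on this axis
      w = jnp.take(w, perm[p], axis=axis)
  return w
```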

Ok, but let us consider the MLP-with-no-bias case. The way the paper models weight matching as an LAP is

$$P_\ell = \arg\max_{P_\ell} \; \langle W_\ell^A, P_\ell W_\ell^B P_{\ell-1}^\top \rangle_F + \langle W_{\ell+1}^A, P_{\ell+1} W_{\ell+1}^B P_\ell^\top \rangle_F = \arg\max_{P_\ell} \; \langle P_\ell, W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \rangle_F$$

In other words, it computes A as

$$A = W_\ell^A P_{\ell-1} (W_\ell^B)^\top + (W_{\ell+1}^A)^\top P_{\ell+1} W_{\ell+1}^B \qquad (1)$$

What the code does, instead—if I understood correctly—is computing A by

  1. Permuting w_b disregarding P_\ell
  2. Transposing it
  3. Multiplying w_a by it

In other words

$$\begin{aligned} A &= W_\ell^A \left(W_\ell^B P_{\ell-1}^\top\right)^\top + W_{\ell+1}^A \left(P_{\ell+1} W_{\ell+1}^B\right)^\top \\ &= W_\ell^A P_{\ell-1} (W_\ell^B)^\top + W_{\ell+1}^A (W_{\ell+1}^B)^\top P_{\ell+1}^\top \end{aligned} \qquad (2)$$

I don't think (1) and (2) are the same thing though.

Hmm I think the error here is in the first line of (2): The shapes here don't line up since $W_\ell^A$ has shape (n, *) and $W_{\ell+1}^A$ has shape (*, n). So adding those things together will result in a shape error if your layers have different widths.

I think tracing out the code for the MLP-without-bias-terms case is a good idea. In that case we run through the for wk, axis in ps.perm_to_axes[p]: loop twice: once for $W_\ell$ and once for $W_{\ell+1}$ (a small numerical check of this trace follows the list below).

  • For $W_\ell$: First of all, axis = 0 since $W_\ell$ has shape (n, *). Then w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $W_\ell^B P_{\ell-1}^T$, in other words $W_\ell^B$ with the other permutations -- $P_{\ell-1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will be a no-op since axis = 0, and w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also be a no-op. So w_a @ w_b.T is $W_\ell^A (W_\ell^B P_{\ell-1}^T)^T = W_\ell^A P_{\ell-1} (W_\ell^B)^T$, which matches up with the first term in the sum.
  • For $W_{\ell+1}$: In this case axis = 1 since $W_{\ell+1}$ has shape (*, n). Then w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) will give us $P_{\ell+1} W_{\ell+1}^B$, in other words $W_{\ell+1}^B$ with the other permutations -- $P_{\ell+1}$ in this case -- applied to the other axes. jnp.moveaxis(w_a, axis, 0).reshape((n, -1)) will result in a transpose, aka $(W_{\ell+1}^A)^T$, and w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1)) will also result in a transpose, aka $(W_{\ell+1}^B)^T P_{\ell+1}^T$. So w_a @ w_b.T is $(W_{\ell+1}^A)^T P_{\ell+1} W_{\ell+1}^B$, which matches up with the second term in the sum.
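
Here is a small numerical check of that trace, with made-up widths (layer sizes 4 -> 3 -> 5) and random fixed permutations; the variable names are just for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
n_prev, n, n_next = 4, 3, 5                       # widths of layers l-1, l, l+1

W_lA, W_lB = rng.normal(size=(n, n_prev)), rng.normal(size=(n, n_prev))
W_nA, W_nB = rng.normal(size=(n_next, n)), rng.normal(size=(n_next, n))
P_prev = np.eye(n_prev)[rng.permutation(n_prev)]  # fixed P_{l-1}
P_next = np.eye(n_next)[rng.permutation(n_next)]  # fixed P_{l+1}

# Closed form (1) from the paper:
A_paper = W_lA @ P_prev @ W_lB.T + W_nA.T @ P_next @ W_nB

# The two loop iterations as traced above:
A_code = W_lA @ (W_lB @ P_prev.T).T               # W_l term: axis = 0, no transposes
A_code += W_nA.T @ ((P_next @ W_nB).T).T          # W_{l+1} term: axis = 1, both transposed

print(np.allclose(A_paper, A_code))               # True
```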

Ok, the role of moveaxis is clear, and the computation matches the formula in the paper for an MLP with no biases.

On the other hand, extending the reasoning to the case with biases, the reshape((n, -1)):

  • Is always a no-op for weight matrices, since n is either the number of rows of $W_\ell$ or the number of columns of $W_{\ell+1}$, which, however, has already been transposed by the moveaxis.
  • Is needed to turn the (n,) bias vectors into (n, 1) matrices, so that w_a @ w_b.T is an (n, n) matrix that can be added to A.

Right?

That's correct! In addition, it's necessary when dealing with weight arrays of more than two dimensions as well, e.g. in a convolutional layer where the weights have shape (w, h, channel_in, channel_out).
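
A quick shape illustration of the bias case, with a made-up width n = 8:

```python
import jax.numpy as jnp

n = 8
b_a = jnp.ones((n,))            # bias from model A, shape (8,)
b_b = jnp.ones((n,))            # bias from model B, shape (8,)

# reshape((n, -1)) turns each (n,) bias into an (n, 1) column...
b_a = b_a.reshape((n, -1))      # (8, 1)
b_b = b_b.reshape((n, -1))      # (8, 1)

# ...so the loop's w_a @ w_b.T is a rank-one (n, n) term that can be added to A.
print((b_a @ b_b.T).shape)      # (8, 8)
```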

Hi, I read the code and I did not really understand the following snippet. Since it relates to the weight matching algorithm, I am posting it here.
At line 199 of weight_matching.py:

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

According to the above line, if W_\ell has shape [m, n] (where m is the output feature dim and n is the input feature dim) in a Dense layer, then the shape of the permutation matrix P_\ell will be [n, n]. But from my reading of the paper, I think it should be [m, m].

Sorry for the silly question, but might you explain? @samuela @frallebini

Thank you!

Hi @LeCongThuong, ps.perm_to_axes is a dict of the form PermutationId => [(ParamId, Axis), ...], where in this case PermutationIds are strings, ParamIds are also strings, and Axes are integers. So, for an MLP (without bias, and assuming that weights have shape [out_dim, in_dim]), this dict would look something like

{ "P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)], ... }

Therefore, axes[0][0] will be something like "Dense_0/kernel" and axes[0][1] will be 0. HTH!
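
To make that concrete, here is a toy evaluation of the perm_sizes line under the same assumptions (a single hypothetical "P_5" entry and made-up layer sizes):

```python
import numpy as np

# Hypothetical shapes, with kernels stored as [out_dim, in_dim].
params_a = {
    "Dense_5/kernel": np.zeros((512, 784)),
    "Dense_6/kernel": np.zeros((10, 512)),
}
perm_to_axes = {"P_5": [("Dense_5/kernel", 0), ("Dense_6/kernel", 1)]}

perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
              for p, axes in perm_to_axes.items()}
print(perm_sizes)  # {'P_5': 512}: P_5 permutes the 512 outputs of Dense_5
```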

Thank you so much for replying @samuela!

I tried to understand ps.perm_to_axes, and I think I got the meaning of Axis: from what I understood from your comment, it lets us know to permute W_b along the axes other than Axis. Following your example above, I think it should be

{ "P_5": [("Dense_5/kernel", 1), ("Dense_6/kernel", 0)], ... }

From that, axes[0][1] will be 1, and thus the shape of P_\ell will be [n, n].

Thank you again for replying to my question.