HomebrewNLP/revlib

Compute input from output?

isaacrob opened this issue · 7 comments

Hey @ClashLuke, nice code! You point to iRevNet as a comparison and explain why this implementation is more memory-efficient. However, iRevNet includes a method that, given the pre-pooler feature set of the CNN, directly inverts the features and computes the input that generated them. I don't see an obvious way to do that with your library. Is that something that can be done, and do you intend to add a method for it? I'm particularly interested in computing the inverse of a transformer given its output hidden states. As far as I can tell, iRevNet doesn't support that since it's CNN-specific, but if it's implementable within your framework, as seems likely, I would be very interested in leveraging it for my research.

Thank you! :)

Hi @isaacrob, thank you for your interest!
We currently only calculate the inverse (and feed it back up) during the backward pass, so you can get the inverse manually right now by calling your model like below:

import torch

from revlib import ReversibleModule

mod = ReversibleModule(torch.nn.Linear(2, 2))
x0 = torch.ones((1, 2))
x1 = torch.ones((1, 2))
x0_back = torch.ones((1, 2))
x1_back = torch.ones((1, 2))
y0, y1, y0_back, y1_back = mod(x0, x1, x0_back, x1_back)
(y0 + y1).sum().backward()
fn_inverse = y0_back.grad

As this is very clunky, I'd propose we implement invertibility natively in RevLib, which means extracting these lines into a new function and exposing it. Do you want to add this, or should I give it a try over the weekend?

If you can't wait for a new implementation, MemCNN also supports inverses natively. However, it's missing some of the other features and uses more memory.

Either way, good luck with your research!

Hello! I am just now seeing your response; I expected an email notification, but for some reason GitHub never sent one. Regardless, I'm still interested in this, and it looks like you haven't added anything yet. If you want, I can give it a shot when I have time!

With 37b8a4f, it's now part of RevLib. Just upgrade to RevLib 1.4.0 and you should be good to go!
If you have any other feature requests or problems you encounter, feel free to raise another issue :)

I would like to add a reversible layer that shuffles/unshuffles (a simple reversible reshape). Could you give an example of how I would write the layer so that it can be used by RevLib?
Something like:

from einops import rearrange

def voxel_unshuffle(x, p=2):
    b, c, d, h, w = x.shape
    y = rearrange(x, "b c (z p1) (y p2) (x p3) -> b (c p1 p2 p3) z y x",
                  z=d // p, y=h // p, x=w // p, p1=p, p2=p, p3=p)
    return y

def voxel_shuffle(x, p=2):
    y = rearrange(x, "b (c p1 p2 p3) z y x -> b c (z p1) (y p2) (x p3)",
                  p1=p, p2=p, p3=p)
    return y

class VoxelUnshuffle(torch.nn.Module):
    def __init__(self, downscale_factor=2):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, x):
        return voxel_unshuffle(x, self.downscale_factor)

    def inverse(self, x):
        return voxel_shuffle(x, self.downscale_factor) 
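As a sanity check on the pair above, the two operations should round-trip exactly. The following is a pure-PyTorch sketch of the same shuffle/unshuffle (the reshape/permute sequence mirrors the einops patterns; function names are reused from the snippet above, and this rewrite is an assumption, not part of the original code):

```python
import torch

# Pure-PyTorch stand-ins for the einops-based pair above (p = scale factor).
def voxel_unshuffle(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    b, c, d, h, w = x.shape
    # Split each spatial axis into (outer, p), move the p-axes into channels.
    x = x.view(b, c, d // p, p, h // p, p, w // p, p)
    x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
    return x.view(b, c * p ** 3, d // p, h // p, w // p)

def voxel_shuffle(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    b, cp, d, h, w = x.shape
    c = cp // p ** 3
    # Pull the p-axes back out of channels and interleave into space.
    x = x.view(b, c, p, p, p, d, h, w)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
    return x.view(b, c, d * p, h * p, w * p)

x = torch.arange(2 * 3 * 4 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4, 4)
assert torch.equal(voxel_shuffle(voxel_unshuffle(x, 2), 2), x)  # exact round-trip
```

Since the operation is a pure permutation of elements, the inverse is exact (no floating-point error), which is what makes it a "simple reversible reshape".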

It seems like you're interested in running a 3D PixelShuffle within the model.
To do that, you don't need an inverse method, as RevLib hooks straight into PyTorch's autograd and handles this for you. I'd recommend wrapping your nn.Module between voxel_shuffle and voxel_unshuffle calls and feeding the wrapped module into RevLib.
For example, the following code works out of the box with RevLib:


class VoxelUnshuffle(torch.nn.Module):
    def __init__(self, inner_module: torch.nn.Module, downscale_factor: int):
        super().__init__()
        self.inner_module = inner_module
        self.downscale_factor = downscale_factor

    def forward(self, x):
        shuffled = voxel_shuffle(x, self.downscale_factor)
        output = self.inner_module(shuffled)
        return voxel_unshuffle(output, self.downscale_factor)

Is this what you were after?
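For concreteness, here is a self-contained sketch of the wrapper in use. The pure-PyTorch voxel_shuffle/voxel_unshuffle below are assumed stand-ins for the einops versions from the thread, and the Conv3d inner module is a hypothetical example; the key property is that the wrapped block preserves its input shape, which is what a reversible coupling requires:

```python
import torch

# Assumed pure-PyTorch stand-ins for the einops-based functions in the thread.
def voxel_shuffle(x, p=2):
    b, cp, d, h, w = x.shape
    c = cp // p ** 3
    return (x.view(b, c, p, p, p, d, h, w)
             .permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
             .view(b, c, d * p, h * p, w * p))

def voxel_unshuffle(x, p=2):
    b, c, d, h, w = x.shape
    return (x.view(b, c, d // p, p, h // p, p, w // p, p)
             .permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
             .view(b, c * p ** 3, d // p, h // p, w // p))

class VoxelUnshuffle(torch.nn.Module):
    def __init__(self, inner_module: torch.nn.Module, downscale_factor: int = 2):
        super().__init__()
        self.inner_module = inner_module
        self.downscale_factor = downscale_factor

    def forward(self, x):
        # Upscale spatially, run the inner module, then fold back to the input shape.
        shuffled = voxel_shuffle(x, self.downscale_factor)
        return voxel_unshuffle(self.inner_module(shuffled), self.downscale_factor)

# Shape-preserving inner module (channels after shuffle: 8 // 2**3 = 1).
block = VoxelUnshuffle(torch.nn.Conv3d(1, 1, 3, padding=1), downscale_factor=2)
x = torch.randn(1, 8, 4, 4, 4)
assert block(x).shape == x.shape  # required for use inside a reversible coupling
```

Because the block maps its input shape to itself, a stack of such modules can then be handed to RevLib's sequential wrapper in the usual way.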

Yes. I was wondering, though, whether the shuffled/unshuffled activations are "stored" by torch's autodiff or recomputed? And say I come up with a layer that I know how to invert: how should I tell RevLib to optimize the backprop memory, given that I can reconstruct the input? Perhaps I can simply rewrite the backward of this layer myself?

That's exactly what RevLib does!
Usually, PyTorch's autograd stores every intermediate value required for backpropagation. With RevLib, some of these are recomputed instead. The difference from gradient checkpointing is that a RevNet doesn't need to store every block's input for recomputation, only the two final outputs.
The core idea behind RevNet is that you don't need to derive or implement dozens of inversions manually; instead, the network is "inverted" automatically by slightly modifying its architecture.
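The additive coupling that makes this possible can be sketched in a few lines. This is the standard RevNet formulation, not RevLib's exact internals; f and g are arbitrary (here, plain Linear layers), and neither needs to be invertible itself:

```python
import torch

# Additive RevNet coupling: any f and g work; no inverse of f or g is needed.
f = torch.nn.Linear(4, 4)
g = torch.nn.Linear(4, 4)

x0, x1 = torch.randn(1, 4), torch.randn(1, 4)

# Forward coupling.
y0 = x0 + f(x1)
y1 = x1 + g(y0)

# Exact inversion from the outputs alone: undo the additions in reverse order.
x1_rec = y1 - g(y0)
x0_rec = y0 - f(x1_rec)

assert torch.allclose(x0_rec, x0, atol=1e-5)
assert torch.allclose(x1_rec, x1, atol=1e-5)
```

Because the inputs can be recomputed from (y0, y1) during the backward pass, the intermediate activations of f and g never need to be kept alive between forward and backward.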

However, if you do have some functions for which you know a manual inversion that might be faster than PyTorch's recomputation, you can still implement it as a torch.autograd.Function. For example, the code below computes instance_norm(relu(x) ** 3 * y + z) and its gradients without storing any of the intermediate values PyTorch would otherwise allocate:
https://github.com/HomebrewNLP/HomebrewNLP/blob/7ae5eb6fc92e28f0aeb78cb0ca4b040d7801f82a/src/model.py#L32-L53
If you're under real memory pressure even after applying RevLib, this might be the best option to trade your own sanity for more GPU memory without reducing speed.
If reducing speed is an option, you can also offload parameters and optimizer states to the CPU, which makes things up to 10x slower but removes the limit that the model has to fit into VRAM. This way, you can fine-tune models like GPT-J on any GPU.