vislearn/FrEIA

Adapt MLP to FrEIA framework

Closed this issue · 1 comment

Hi all,

I'm trying to create an invertible MLP, specifically a Siren (an MLP with periodic activations, from Sitzmann et al. 2020).

This is a simplified Siren (I removed some code to keep the example short, but the full version can be found here):

class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()        
        self.net = []
        self.net.append(SineLayer(in_features, hidden_features, is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features, is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)
                
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features, is_first=False, omega_0=hidden_omega_0))
        
        self.net = nn.Sequential(*self.net)
    
    def forward(self, coords):
        coords = coords.clone().detach().requires_grad_(True)  # allows taking derivatives w.r.t. the input
        output = self.net(coords)
        return output, coords 

Now I need to adapt the Siren to be able to plug it into FrEIA.

My problem is very similar to this one, and, trying to follow some of the suggestions, I adapted the Siren model to this:

class Siren_FREIA(nn.Module):

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.hidden_layers = hidden_layers
        self.out_features = out_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0
        
    def forward(self, inp, out):
        net = OrderedDict()
        
        net['SL1'] = SineLayer(inp, self.hidden_features, 
                                  is_first=True, omega_0=self.first_omega_0)

        for i in range(self.hidden_layers):
            net[f'SLh{i}'] = SineLayer(self.hidden_features, self.hidden_features, is_first=False, omega_0=self.hidden_omega_0)

        if self.outermost_linear:
            final_linear = nn.Linear(self.hidden_features, out)
            
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / self.hidden_features) / self.hidden_omega_0, 
                                              np.sqrt(6 / self.hidden_features) / self.hidden_omega_0)
                
            net['final_linear'] = final_linear
        else:
            net['some_other_sine'] = SineLayer(self.hidden_features, out, is_first=False, omega_0=self.hidden_omega_0)
        
        net = nn.Sequential(net)       
        return net

In my case, running this Siren model requires an input shape (which is the same as the output shape) of (BATCH_SIZE, 16384, 3), where BATCH_SIZE = 1.

This is how I initialize the INN:

raw_scores_siren = Siren_FREIA(in_features=3, out_features=3, hidden_features=256, hidden_layers=3, outermost_linear=True)      
input_dims = (16384 , 3) #does not include batch size, which is 1
inn = Ff.SequenceINN(*input_dims)
inn.append(Fm.AllInOneBlock, subnet_constructor=raw_scores_siren, permute_soft=True)
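For reference, my adaptation tries to imitate the plain-function subnet_constructor pattern from the FrEIA examples, at least as I understand it: FrEIA calls the constructor once per coupling block with flat channel counts (dims_in, dims_out) and expects a ready-to-use module back. The subnet_fc below is a generic fully-connected stand-in (swapping its layers for SineLayers would give the Siren version); running the snippet above is what produces the warning and error below.

```python
# Sketch of the subnet_constructor pattern as I understand it: a plain
# function that FrEIA calls once per coupling block, receiving flat channel
# counts and returning a ready-to-use module. The layer widths here (256)
# are just my hidden_features choice, not anything FrEIA mandates.
import torch
import torch.nn as nn

def subnet_fc(dims_in, dims_out):
    # Generic fully-connected subnet; replace the layers with SineLayers
    # (from the Siren codebase) to get a Siren-style subnet.
    return nn.Sequential(
        nn.Linear(dims_in, 256),
        nn.ReLU(),
        nn.Linear(256, dims_out),
    )

net = subnet_fc(8, 16)   # FrEIA would pass e.g. half the channels in/out
x = torch.randn(4, 8)    # batch of 4, 8 channels
print(net(x).shape)      # torch.Size([4, 16])
```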

However, I get the following warning: "UserWarning: Soft permutation will take a very long time to initialize with 16384 feature channels. Consider using hard permutation instead." After a long time (~20 minutes), I get the following error message:

"mat1 and mat2 shapes cannot be multiplied (2x16384 and 2x256)"

Because of this, I've tried to:

  1. Set permute_soft=False. I get the same error message.

  2. Transpose the input to (3, 16384), while keeping batch_size = 1 (I don't know if this makes sense). I get the same error message.

  3. Eliminate the batch dimension of 1 and use 16384 as the batch size instead, so that input_dims = (3,). This runs (trains) until the end, but the results are not the intended ones (i.e., they do not match what I obtain when I train the Siren alone, without plugging it into FrEIA).

  4. Duplicate the input, so that the batch size is 2 (I read that there could be issues with a batch size of 1). However, this also does not work.

So my question is, what am I doing wrong when adapting the Siren model to the Siren_FREIA model?

Sorry for the long post; if anything is unclear, I can explain further. Thank you in advance!

Adding the color channel seems to have fixed the problem:

class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        
        self.net = []
        self.net.append(SineLayer(in_features, hidden_features, 
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features, 
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)
                
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features, 
                                      is_first=False, omega_0=hidden_omega_0))
        
        self.net = nn.Sequential(*self.net)
    
    def forward(self, in_features, out_features):       
        return self.net


input_dims = (1, 16384, 3)  # added color channel; does not include the batch size, which is 1
inn = Ff.SequenceINN(*input_dims)
inn.append(Fm.AllInOneBlock, subnet_constructor=raw_scores_siren, permute_soft=False)
inn = inn.to(device)
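As a sanity check that the assembled INN really behaves invertibly, a forward/inverse round trip is worth running before training. The snippet below demonstrates the idea with a minimal hand-rolled additive coupling layer in plain PyTorch (a stand-in for illustration, not FrEIA's actual block), so it runs without FrEIA installed:

```python
# Minimal additive coupling layer (a stand-in, not FrEIA's implementation)
# to illustrate the round-trip check: apply forward, then the inverse, and
# confirm the input is recovered.
import torch
import torch.nn as nn

class AdditiveCoupling(nn.Module):
    def __init__(self, channels, hidden=64):
        super().__init__()
        self.split = channels // 2
        self.subnet = nn.Sequential(
            nn.Linear(self.split, hidden), nn.ReLU(),
            nn.Linear(hidden, channels - self.split),
        )

    def forward(self, x, rev=False):
        a, b = x[:, :self.split], x[:, self.split:]
        t = self.subnet(a)               # a passes through unchanged,
        b = b - t if rev else b + t      # so the shift is exactly invertible
        return torch.cat([a, b], dim=1)

torch.manual_seed(0)
layer = AdditiveCoupling(6)
x = torch.randn(5, 6)
x_rec = layer(layer(x), rev=True)        # forward, then inverse
print(torch.allclose(x, x_rec, atol=1e-6))  # True
```

With a real FrEIA SequenceINN the same check would use the model's reverse mode instead of the hand-rolled layer.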