Adapt MLP to FrEIA framework
annukkaa opened this issue · 1 comment
Hi all,
I'm trying to create an invertible MLP, specifically a Siren (an MLP with periodic activations, from Sitzmann et al. 2020).
This is a simplified Siren (I removed some code to keep the example simple, but the full version can be found here):
import numpy as np
import torch
from torch import nn

# SineLayer is defined in the full code linked above (removed here for brevity)

class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))
        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))
        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                             np.sqrt(6 / hidden_features) / hidden_omega_0)
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))
        self.net = nn.Sequential(*self.net)

    def forward(self, coords):
        coords = coords.clone().detach().requires_grad_(True)  # allows taking derivatives w.r.t. the input
        output = self.net(coords)
        return output, coords
Now I need to adapt the Siren to be able to plug it into FrEIA.
My problem is very similar to this one and, following some of the suggestions there, I adapted the Siren model as follows:
from collections import OrderedDict

class Siren_FREIA(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.hidden_layers = hidden_layers
        self.out_features = out_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0

    def forward(self, inp, out):
        # build the network on the fly from the channel counts FrEIA passes in
        net = OrderedDict()
        net['SL1'] = SineLayer(inp, self.hidden_features,
                               is_first=True, omega_0=self.first_omega_0)
        for i in range(self.hidden_layers):
            net[f'SLh{i}'] = SineLayer(self.hidden_features, self.hidden_features,
                                       is_first=False, omega_0=self.hidden_omega_0)
        if self.outermost_linear:
            final_linear = nn.Linear(self.hidden_features, out)
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / self.hidden_features) / self.hidden_omega_0,
                                             np.sqrt(6 / self.hidden_features) / self.hidden_omega_0)
            net['final_linear'] = final_linear
        else:
            net['some_other_sine'] = SineLayer(self.hidden_features, out,
                                               is_first=False, omega_0=self.hidden_omega_0)
        net = nn.Sequential(net)
        return net
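For reference, FrEIA expects subnet_constructor to be a callable that takes two integers (the input and output channel counts of the subnet) and returns an nn.Module mapping between them; an instance of Siren_FREIA satisfies this because calling it invokes forward(inp, out). A minimal sketch of the usual fully connected pattern, for comparison (the 512 hidden width is an arbitrary choice):

def subnet_fc(dims_in, dims_out):
    # dims_in / dims_out are the channel counts FrEIA passes in when building the block
    return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(),
                         nn.Linear(512, dims_out))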
Running this Siren model requires, in my case, an input shape (which is the same as the output shape) of (BATCH_SIZE, 16384, 3), where BATCH_SIZE = 1.
This is how I initialize the INN:
import FrEIA.framework as Ff
import FrEIA.modules as Fm

raw_scores_siren = Siren_FREIA(in_features=3, out_features=3, hidden_features=256, hidden_layers=3, outermost_linear=True)
input_dims = (16384, 3)  # does not include batch size, which is 1
inn = Ff.SequenceINN(*input_dims)
inn.append(Fm.AllInOneBlock, subnet_constructor=raw_scores_siren, permute_soft=True)
However, I get the following warning: "UserWarning: Soft permutation will take a very long time to initialize with 16384 feature channels. Consider using hard permutation instead." After a long time (~20 minutes), I get the following error message:
"mat1 and mat2 shapes cannot be multiplied (2x16384 and 2x256)"
Because of this, I've tried to:
- Set permute_soft=False. I get the same error message.
- Transpose the input to (3, 16384) (while keeping batch_size = 1; I don't know if that makes sense). I get the same error message.
- Eliminate the batch dimension of 1 and set the batch size to 16384, so that input_dims = (3,); see the sketch after this list. This runs (trains) to the end, but the results are not the intended ones (i.e., they are not the same as those I obtain when I train the Siren alone, without plugging it into FrEIA).
- Duplicate the input so that the batch size is 2 (I read that there could be issues with a batch size of 1). However, this also does not work.
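For concreteness, the third attempt corresponds to wiring things up roughly like this (a sketch; coords stands for my (1, 16384, 3) input tensor):

# treat each of the 16384 points as a batch element, so the INN sees 3 feature channels
input_dims = (3,)
inn = Ff.SequenceINN(*input_dims)
inn.append(Fm.AllInOneBlock, subnet_constructor=raw_scores_siren, permute_soft=True)
x = coords.reshape(-1, 3)  # (16384, 3): points become the batch dimension
z, log_jac_det = inn(x)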
So my question is, what am I doing wrong when adapting the Siren model to the Siren_FREIA model?
Sorry for the long post; if anything is unclear, I can explain further. Thank you in advance!
Adding the color channel seems to have fixed the problem:
class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()
        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))
        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))
        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                             np.sqrt(6 / hidden_features) / hidden_omega_0)
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))
        self.net = nn.Sequential(*self.net)

    def forward(self, in_features, out_features):
        # called by FrEIA with the channel counts; returns the prebuilt net (the dims are ignored)
        return self.net
input_dims = (1, 16384, 3)  # added color channel; does not include batch size, which is 1
inn = Ff.SequenceINN(*input_dims)
inn.append(Fm.AllInOneBlock, subnet_constructor=raw_scores_siren, permute_soft=False)
inn = inn.to(device)
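With this setup, a quick sanity check of the forward and inverse passes looks roughly like the following (a sketch; x is a dummy input with the batch dimension of 1 prepended to input_dims):

x = torch.randn(1, 1, 16384, 3).to(device)  # (batch, channel, points, coords)
z, log_jac_det = inn(x)                     # forward pass through the INN
x_rec, _ = inn(z, rev=True)                 # inverse pass should recover x
print(torch.allclose(x, x_rec, atol=1e-5))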