linear activation function gives errors
EnricoTrizio opened this issue · 2 comments
In the list of activation functions there is also `linear`, which does nothing except print a warning:
import torch


def get_activation(activation: str):
    """Return activation module given string."""
    activ = None
    if activation == "relu":
        activ = torch.nn.ReLU(True)
    # ...
    elif activation == "linear":
        # activ is never assigned in this branch, so the function returns None
        print("WARNING: no activation selected")
    elif activation is None:
        pass
    else:
        raise ValueError(
            f"Unknown activation: {activation}. options: 'relu','elu','tanh','softplus','shifted_softplus','linear'. "
        )
    return activ
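So if `linear` is used (which is a perfectly sensible choice), `activ` stays `None` and the caller receives `None` instead of a module, which causes errors wherever the result is used as a layer.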
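A minimal reproduction, assuming a trimmed copy of the function that keeps only the `relu` and `linear` branches: asking for `linear` returns `None`, and calling that result like a layer fails.

```python
import torch


def get_activation(activation):
    # Trimmed copy of the function above: only the branches needed here.
    activ = None
    if activation == "relu":
        activ = torch.nn.ReLU(True)
    elif activation == "linear":
        print("WARNING: no activation selected")
    return activ


activ = get_activation("linear")
print(activ)  # None
try:
    activ(torch.zeros(2))
except TypeError as err:
    # 'NoneType' object is not callable
    print("TypeError:", err)
```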
Maybe we can create a dummy activation function that does nothing but is still a `torch.nn.Module`:
class No_Activation(torch.nn.Module):
    def __init__(self):
        super(No_Activation, self).__init__()

    def forward(self, input):
        return input
Also, the list of available activations must be updated.
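One way the fix could look, as a sketch rather than the project's actual code: PyTorch already provides `torch.nn.Identity`, which returns its input unchanged just like the proposed `No_Activation`, so the `linear` branch can return it directly without a custom class. `get_activation_fixed` is a hypothetical name, and only a few of the original options are included.

```python
import torch


def get_activation_fixed(activation):
    """Return an activation module given its name (hypothetical fixed variant).

    Only a subset of the original options is shown.
    """
    if activation == "relu":
        return torch.nn.ReLU(True)
    elif activation == "tanh":
        return torch.nn.Tanh()
    elif activation == "linear":
        # Identity passes its input through unchanged but is still a
        # torch.nn.Module, so it can sit in a Sequential like any other layer.
        return torch.nn.Identity()
    elif activation is None:
        return None
    else:
        raise ValueError(
            f"Unknown activation: {activation}. options: 'relu','tanh','linear'."
        )


# Usage: "linear" now yields a working (pass-through) layer.
layer = torch.nn.Sequential(torch.nn.Linear(3, 2), get_activation_fixed("linear"))
x = torch.randn(4, 3)
print(layer(x).shape)  # torch.Size([4, 2])
```

With this approach callers never need to special-case `linear`; the trade-off is one extra (no-op) module in the network.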
I think that works, but it will probably end up adding more lines of code than it saves compared to handling the case explicitly as usual (`'linear'` here really means `None`). I don't think that function is called in many places anyway. What if we remove that case and just call the function like this:
if activ != 'linear':
    get_activation(activ)
else:
    logger.warning('No activation function')
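The caller-side alternative might look roughly like this sketch. `build_layers` is a hypothetical helper (not from the project) that stacks `torch.nn.Linear` layers and adds an activation after each hidden layer only when `activ != 'linear'`; the `get_activation` here is a trimmed stand-in for the real function. It uses `logger.warning`, since `Logger.warn` is a deprecated alias.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def get_activation(activation):
    # Trimmed stand-in for the real helper: only two options shown.
    options = {"relu": lambda: torch.nn.ReLU(True), "tanh": torch.nn.Tanh}
    if activation not in options:
        raise ValueError(f"Unknown activation: {activation}")
    return options[activation]()


def build_layers(sizes, activ):
    """Stack Linear layers, adding an activation after each hidden layer
    unless activ == 'linear' (hypothetical helper)."""
    layers = []
    n_layers = len(sizes) - 1
    for i in range(n_layers):
        layers.append(torch.nn.Linear(sizes[i], sizes[i + 1]))
        # No activation after the output layer; 'linear' disables it entirely.
        if i < n_layers - 1 and activ != "linear":
            layers.append(get_activation(activ))
    if activ == "linear":
        logger.warning("No activation function")
    return torch.nn.Sequential(*layers)


net = build_layers([3, 8, 1], "relu")  # Linear -> ReLU -> Linear
net_lin = build_layers([3, 8, 1], "linear")  # Linear -> Linear, logs a warning
```

Here `'linear'` never reaches `get_activation`, so that function only has to handle real activations.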
Sorry, I noticed this because the committor has an activation on the last layer, but it is sometimes useful to turn it off.
I can move it to the post-processing part if needed.
I was also thinking that if someone chooses `linear`, it would be nice for it to actually work.