Train multiple hk.nets.MLP with one optimizer
rsmath opened this issue · 2 comments
rsmath commented
Hello,
I am trying to train two neural networks simultaneously with one optimizer. In PyTorch this is trivial, since the outputs of model.parameters() for each model can be concatenated and passed to the optimizer. How do I accomplish this in general in Haiku? I have two parameter variables (one per network) from the networks' individual init functions, along with the accompanying apply functions.
Thank you.
Ekundayo39283 commented
You can try the format below to see if it works for you:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

# Define your two networks
class Network1(hk.Module):
    def __call__(self, x):
        # Network 1 architecture (example: a small MLP)
        return hk.nets.MLP([64, 1])(x)

class Network2(hk.Module):
    def __call__(self, x):
        # Network 2 architecture (example: a small MLP)
        return hk.nets.MLP([32, 1])(x)

# Transform each network: hk.Module instances may only be created
# inside a function passed to hk.transform
net1 = hk.without_apply_rng(hk.transform(lambda x: Network1()(x)))
net2 = hk.without_apply_rng(hk.transform(lambda x: Network2()(x)))

# Initialize parameters for both networks (example inputs, separate RNG keys)
rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jnp.ones((8, 3))
x2 = jnp.ones((8, 5))
params1 = net1.init(rng1, x1)
params2 = net2.init(rng2, x2)

# Combine both parameter pytrees into a single pytree
all_params = {"net1": params1, "net2": params2}

# Pass the combined pytree to one optax optimizer
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(all_params)

# Call each network with its own sub-tree, e.g.
# net1.apply(all_params["net1"], x1) and net2.apply(all_params["net2"], x2)
rsmath commented
Yes, I managed to make one pytree of parameters to pass to the optimizer and it has been working fine. Thank you.