cybertronai/pytorch-sso

Can VOGN optimise the input tensor instead of the weights of the model?

razvanmarinescu opened this issue · 3 comments

I have an optimisation problem where I find the optimal latent z that gives me the image closest to an input image I:

argmin_z || G(z;w) - I ||

where the pre-trained neural net G( . ;w) has fixed weights w. I want to use VOGN to sample from the posterior p(z | I, w). My setting is a bit more complicated, but this is the rough idea.

Can I use VOGN to optimise the input tensor z, instead of the weights w of the model? I see that I need to pass the entire model G to the VOGN constructor, which will likely try to optimise the weights w of G instead of the input parameters z that I want the posterior over. How can I tell it to optimise the inputs z to my model instead?

Thanks,
Raz

One idea I'm trying now is to define a fake module M at the beginning of the pipeline that simply returns z, with M.parameters() yielding z.
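A minimal sketch of that fake-module idea (the names `LatentModule` and `z_dim` are mine, not from pytorch-sso; whether VOGN's curvature computation accepts a bare parameter like this is a separate question):

```python
import torch
import torch.nn as nn

class LatentModule(nn.Module):
    """Wraps the latent code z as the module's only learnable parameter."""
    def __init__(self, z_dim):
        super().__init__()
        # Registering z as a Parameter makes M.parameters() yield exactly z
        self.z = nn.Parameter(torch.randn(z_dim))

    def forward(self):
        # Ignore any input and simply return the latent code
        return self.z

M = LatentModule(z_dim=512)
# M.parameters() contains only z, so an optimizer built on M touches nothing else
```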

Then would I just need to pass M to VOGN as follows? VOGN(M, dataset_size=1)

Or do I need to pass the entire pipeline up to the point where the loss is evaluated, even though I don't need a posterior over any parameters in G ... VOGN([M, G], dataset_size=1)?

dataset_size is 1 as I optimize one image at a time.

One more question. In the closure function:

```python
def closure():
    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    loss.backward(create_graph=args.create_graph)
    return loss, output
```

How is the output used? For normal classification on MNIST, this is a vector of logits to which softmax is applied. In my case, I'm running it on a generative model (StyleGAN), so the output is an NxN image and the final loss is composed of multiple terms (a pixelwise L2 loss and a perceptual loss). In the closure function, what should I return as the output?
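Concretely, the closure in a setting like this might look as follows (a sketch under my own assumptions: `G` is the fixed pre-trained generator, `M` the fake latent module, `perceptual_loss` some perceptual distance, and the weighting `lam` is arbitrary):

```python
import torch
import torch.nn.functional as F

def make_closure(optimizer, M, G, target_image, perceptual_loss, lam=0.1):
    """Build a VOGN-style closure for latent optimization.

    The closure returns (loss, output); here output is simply the
    generated image, since there are no class logits in this setting.
    """
    def closure():
        optimizer.zero_grad()
        z = M()                      # fake module returns the latent code
        output = G(z)                # generated image
        loss = (F.mse_loss(output, target_image)
                + lam * perceptual_loss(output, target_image))
        loss.backward()
        return loss, output
    return closure
```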

sarihl commented

I know this is old, but in case it is relevant to someone: you can achieve this by defining an nn.Module that has a single Linear(1, needed_input_dim, bias=False) layer; after optimization, the weights of this layer are 'z'.
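A sketch of that trick (the class name and the call with a constant input of 1 are my own illustration; the point is that z lives inside a standard Linear layer, which pytorch-sso's per-layer curvature code knows how to handle):

```python
import torch
import torch.nn as nn

class LatentAsLinear(nn.Module):
    """Encodes z as the weight of a bias-free Linear(1, needed_input_dim)."""
    def __init__(self, needed_input_dim):
        super().__init__()
        self.fc = nn.Linear(1, needed_input_dim, bias=False)

    def forward(self):
        # Feeding a constant 1 makes the output equal the weight column,
        # so the layer's output is exactly z = fc.weight[:, 0]
        one = torch.ones(1, 1)
        return self.fc(one).squeeze(0)

m = LatentAsLinear(needed_input_dim=8)
z = m()  # z is tied to m.fc.weight, which the optimizer updates
```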