vicgalle/stable-diffusion-aesthetic-gradients

Why optimize the full model?

jorgemcgomes opened this issue · 3 comments

If I understand correctly, all the weights of the CLIP text encoder are optimized, which naturally incurs a non-negligible computational cost.

Why was this chosen as opposed to just training part of the model?
My intuition would be to just optimize the last ~2 layers of the CLIP encoder.

Were there any experiments in this direction?

Hi @jorgemcgomes!
At first, I ran some tests optimizing just the final projection layer of the CLIP encoder, but I was unsatisfied with the results: the generated image only varied slightly. So I decided to optimize the full model instead. Regarding the computational cost, I found it was worth it. In my case (tested on a Tesla V100), running around 10 steps of the CLIP optimization takes about 1 second, whereas the subsequent diffusion process takes around 12 seconds (for 50 diffusion steps and 3 images). So I believe that's a reasonable tradeoff.
But as you said, it may be interesting to further study some intermediate regime. For example, I could produce a figure similar to Fig. 2, computing the scores of the generations when fine-tuning the last layer, the last 2 layers, and so on up to all the layers, to see whether there is a monotonic trend.
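For concreteness, here is a minimal sketch of the kind of few-step loop described above, assuming the Hugging Face `transformers` CLIP API; the random placeholder below stands in for a real aesthetic embedding (e.g. the mean CLIP image embedding of a set of liked images), and the exact code in this repo differs:

```python
# Minimal sketch, not the repo's actual implementation: fine-tune the CLIP
# text encoder for a few steps so the prompt embedding moves toward a
# precomputed aesthetic embedding.
import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

tokens = tokenizer("a painting of a fantasy landscape", return_tensors="pt")
aesthetic_emb = F.normalize(torch.randn(1, 768), dim=-1)  # placeholder embedding

for _ in range(10):  # ~10 steps, as mentioned above
    optimizer.zero_grad()
    text_emb = F.normalize(model(**tokens).text_embeds, dim=-1)
    loss = -(text_emb * aesthetic_emb).sum()  # maximize cosine similarity
    loss.backward()
    optimizer.step()
```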

Thanks for the reply @vicgalle .

I see why optimizing just the final projection could fail; after all, it's a simple linear projection. I've tried the same thing (in a different context) and reached the same conclusion: very little changes when you only optimize the final CLIP projection. In my experience, though, you can change quite a lot by optimizing only the last few layers of the CLIP model plus everything that comes after them (the final layer norm and the projection). You might need to adjust the number of optimization steps, but each step would be a lot faster.
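As a rough illustration, this is one way to set that up, assuming the Hugging Face `transformers` CLIP implementation (this repo may wire the encoder differently):

```python
# Sketch of partial fine-tuning: freeze all of CLIP's text encoder except
# the last N transformer blocks, the final layer norm, and the projection.
import torch
from transformers import CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

N = 2  # number of final transformer blocks to leave trainable

# Freeze everything first...
for p in model.parameters():
    p.requires_grad = False

# ...then unfreeze the last N blocks plus everything after them.
for module in [*model.text_model.encoder.layers[-N:],
               model.text_model.final_layer_norm,
               model.text_projection]:
    for p in module.parameters():
        p.requires_grad = True

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```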

Regarding the computational cost, I think the most significant part is VRAM rather than time: the weights, gradients, and optimizer states for a full CLIP text model already take a couple of GB of VRAM, before counting activations.
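For scale, a quick estimate, assuming fp32 Adam on the ViT-L/14 text encoder (~123M parameters):

```python
# Back-of-envelope VRAM estimate for full fine-tuning with Adam in fp32.
# Activations and the diffusion model itself come on top of this.
n_params = 123e6
bytes_per_param = 4 + 4 + 8  # fp32 weights + gradients + two Adam moments
print(f"{n_params * bytes_per_param / 2**30:.2f} GiB")  # ~1.83 GiB
```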

eeyrw commented

I am also curious about the VRAM consumption of optimizing the CLIP text model at runtime.