parameter count is inaccurate
jameshensman opened this issue · 1 comment
jameshensman commented
SliceGPT adds linear modules in the "skip connections" around each block. We currently implement these as PyTorch buffers, which causes two problems:
- counting the parameters is inaccurate (since buffers are not counted by default)
- it's not straightforward to fine-tune these additional matrices (which may improve performance?)
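To illustrate the first problem, here is a minimal sketch, not the SliceGPT code itself. The class `BufferShortcut`, the buffer name `shortcut_Q`, and the two counting helpers are hypothetical names chosen for this example. It shows that a matrix registered with `register_buffer` is skipped by `Module.parameters()`, so the usual parameter count misses it.

```python
import torch
from torch import nn


class BufferShortcut(nn.Module):
    """Hypothetical block whose skip-connection matrix is stored as a buffer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        # Buffers are saved in the state dict but are not parameters.
        self.register_buffer("shortcut_Q", torch.eye(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x) + x @ self.shortcut_Q


def count_parameters(module: nn.Module) -> int:
    # The usual way of counting: only nn.Parameter tensors are visited.
    return sum(p.numel() for p in module.parameters())


def count_buffers(module: nn.Module) -> int:
    # Buffers have to be counted separately.
    return sum(b.numel() for b in module.buffers())


block = BufferShortcut(dim=8)
n_params = count_parameters(block)   # counts only proj.weight
n_buffers = count_buffers(block)     # counts only shortcut_Q
print(n_params, n_buffers)
```

The buffer also has `requires_grad=False` and is not returned by `parameters()`, so an optimizer built from `block.parameters()` would never update it, which is the second problem listed above.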
Proposed solution:
Change the buffers to actual linear layers, initialized as torch.eye(). At every place in the code where we check whether this buffer is None, we should replace that bit of code. During slicing, make sure to modify these buffers correctly. Update the finetuning script (and BO script) to enable finetuning of these linears.
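A rough sketch of what the proposal could look like, under stated assumptions: `make_identity_linear`, `slice_linear`, and `LinearShortcut` are hypothetical names, and slicing is simplified here to keeping the leading rows and columns of the weight. The actual change made in the repository may differ.

```python
import torch
from torch import nn


def make_identity_linear(dim: int) -> nn.Linear:
    """Bias-free linear layer whose weight starts as the identity matrix."""
    layer = nn.Linear(dim, dim, bias=False)
    nn.init.eye_(layer.weight)  # in-place identity initialization
    return layer


def slice_linear(layer: nn.Linear, new_in: int, new_out: int) -> nn.Linear:
    """Return a smaller layer keeping the leading rows/columns (assumed scheme)."""
    sliced = nn.Linear(new_in, new_out, bias=False)
    with torch.no_grad():
        # nn.Linear stores its weight as (out_features, in_features).
        sliced.weight.copy_(layer.weight[:new_out, :new_in])
    return sliced


class LinearShortcut(nn.Module):
    """Hypothetical block whose skip connection is a trainable linear layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.shortcut = make_identity_linear(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x) + self.shortcut(x)


block = LinearShortcut(dim=8)
total = sum(p.numel() for p in block.parameters())  # now includes the shortcut
x = torch.randn(2, 8)
sliced = slice_linear(block.shortcut, new_in=6, new_out=6)
```

Because the shortcut is now an ordinary `nn.Linear`, it appears in `block.parameters()`, so it is counted by default and can be passed to an optimizer for finetuning. With this design the layer always exists, so checks against `None` would no longer be needed.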
nailimixaM commented
Solved in #100.