Memory issues when using the "distillation" class
danyaljj opened this issue · 1 comment
We're trying to use your StudentTeacher class (which seems to have been added by @mmatena) to do distillation with T5 (there are some details on our implementation in this issue).
Here is our current state: for larger "teacher" models we get an obscure error message, "could not parse rpc response", which (thanks to @jysohn23) we know is caused by out-of-memory errors. What is odd is that models like T5-11B fine-tune just fine on their own on, say, a v3-8 TPU. However, when paired with another model (say, T5-small) for "distillation", the job requires much larger TPUs. So our current guess is that mesh-tensorflow is not distributing the models/data appropriately for the distillation setup; a sketch of our hypothesis is below.
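To make the hypothesis concrete (this is just how we currently understand it, and the dimension names below are illustrative rather than the exact ones the library builds internally): if the mesh for the distillation run is pure data parallelism, every core would hold a full copy of both the teacher and the student variables, whereas a model-parallel mesh would split the variables across cores.

```python
# Rough sketch of the two mesh layouts we have in mind (hypothetical example,
# not the exact shapes the T5/mesh-tensorflow utilities construct).
import mesh_tensorflow as mtf

# Pure data parallelism on 8 cores: every core holds a full copy of
# both the teacher and the student variables.
data_parallel_mesh = mtf.Shape(
    [mtf.Dimension("batch", 8), mtf.Dimension("model", 1)])

# Pure model parallelism on 8 cores: each core holds only a 1/8 slice
# of the variables, which is what a T5-11B teacher seems to need.
model_parallel_mesh = mtf.Shape(
    [mtf.Dimension("batch", 1), mtf.Dimension("model", 8)])

# Inspect the two layouts.
print(data_parallel_mesh)
print(model_parallel_mesh)
```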
Wondering if you have any thoughts or suggestions you can share with us: is there a bug here, or is this the expected behavior?
FYI @sbhaktha
Looks like the issue was due to our `model_parallelism` parameter. Closing this for now.
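In case it helps anyone else, here is a minimal sketch of the change that resolved it for us, assuming the gin-configurable `utils.run` entry point from the mesh-tensorflow transformer codebase (the value 8 is just an example; the right setting depends on the teacher size and TPU topology):

```python
# Minimal sketch of the fix (assumption: utils.run.model_parallelism is the
# gin parameter that controls how many cores the "model" dimension is split over).
import gin

# Equivalent to passing --gin_param="utils.run.model_parallelism = 8"
# on the command line when launching the distillation run.
gin.parse_config('utils.run.model_parallelism = 8')
```

With model parallelism greater than 1, each core only holds a slice of the teacher's variables instead of a full copy, which, as far as we can tell, is what was exceeding per-core memory.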