tensorflow/mesh

Memory issues when using the "distillation" class

danyaljj opened this issue · 1 comment

@nshazeer @mmatena

We're trying to use your StudentTeacher class (which seems to have been added by @mmatena) to do distillation with T5 (some details on our implementation are in this issue).

Here is our current state: for larger "teacher" models we get an obscure error, "could not parse rpc response", which (thanks to @jysohn23) we know is due to out-of-memory errors. What is odd is that a model like T5-11B fine-tunes just fine on its own on, say, a v3-8 TPU. However, when it is paired with another model (say, T5-small) for distillation, training requires much larger TPUs. So our current guess is that mesh-tensorflow is not distributing the model/data appropriately for distillation.
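For a rough sense of scale (back-of-envelope only, using published parameter counts and v3-8 HBM sizes, not measurements from our runs):

```python
# Approximate, float32 weights only; no activations, gradients, or optimizer
# slots. During distillation both the teacher and the student live on the same
# mesh, so their variables compete for the same HBM.
teacher_params = 11e9   # T5-11B
student_params = 60e6   # T5-small
bytes_per_param = 4     # float32
hbm_total_gb = 8 * 16   # a v3-8 has 8 cores with 16 GB HBM each

weights_gb = (teacher_params + student_params) * bytes_per_param / 1e9
print(f"weights alone: ~{weights_gb:.0f} GB of {hbm_total_gb} GB total HBM")
# ~44 GB before anything else -- so unless the layout actually splits the
# teacher across cores, a single core's 16 GB is nowhere near enough.
```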

Wondering if you have any thoughts/suggestions you can share with us: is there a bug here, or is this the expected behavior?

FYI @sbhaktha

Looks like the issue was due to our model_parallelism parameter. Closing it for now.
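For anyone who hits the same error: the fix for us was matching model parallelism to the TPU topology. A minimal sketch of the kind of setting involved, assuming the `t5.models.MtfModel` entry point; the paths, names, and numbers below are placeholders, not our actual configuration:

```python
import t5

model = t5.models.MtfModel(
    model_dir="gs://my-bucket/models/distill-teacher",  # hypothetical path
    tpu="my-tpu-name",                                  # hypothetical TPU name
    tpu_topology="v3-8",
    # With 8 cores, model_parallelism=8 splits the model variables across all
    # cores (leaving a data-parallel factor of 1). Too small a value means each
    # core must hold a larger slice of the teacher + student variables, which
    # is what ran us out of memory.
    model_parallelism=8,
    batch_size=8,
    sequence_length={"inputs": 512, "targets": 512},
)
```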