google-deepmind/ferminet

An issue regarding multi-node training with TF code

Closed this issue · 4 comments

Hi there. So I'm trying to run the TF version of ferminet on multiple nodes in a GPU cluster, with the fairly naive idea of replacing MirroredStrategy with MultiWorkerMirroredStrategy (I have some experience with TF's distributed training via its Estimator API, but not with the low-level training loop or with Sonnet).
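
In case it helps, the change itself was essentially a one-liner; a rough sketch of what I mean (the real ferminet training loop obviously has more moving parts):

```python
import tensorflow as tf  # TF 1.15

# Before (single node, multiple GPUs):
# strategy = tf.distribute.MirroredStrategy()

# Naive replacement (multiple nodes); the cluster topology is read from
# the TF_CONFIG environment variable on each worker process.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    pass  # build the network and optimizer here, as before
```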

Unfortunately, the attempt failed: on a worker node, the strategy tried to place tensors on a device named like /job:worker/replica:0/task:0/device:GPU:0. The /job:worker part is the problem, since all the available devices start with /job:localhost instead (as I understand it, TF should create correctly named devices, but somehow it didn't). In my setup I didn't use any PS node, and I'm not sure whether a PS node is required when using MultiWorkerMirroredStrategy in TF 1.15, or whether that's related to this behavior.
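
For context, this is roughly how I configured the cluster on each node (hostnames and ports here are placeholders, not my actual setup):

```python
import json
import os

# One TF_CONFIG per process; the entries differ only in task.index.
# It has to be set before the strategy object is constructed, otherwise
# devices keep their default /job:localhost names.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["node0:2222", "node1:2222"]},
    "task": {"type": "worker", "index": 0},  # index 1 on the second node
})
```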

So have you guys tried multi-node training with the TF version of the code? If so, did you use MultiWorkerMirroredStrategy, and did you run into this issue or something similar? Any comments are appreciated, thanks!

By the way, to my knowledge JAX does not yet support multi-node training, does it?

We have not tried MultiWorkerMirroredStrategy -- getting everything to work well with MirroredStrategy was very involved but sufficient for all published results. We have successfully done multi-worker training using TF-Replicator, so in principle it should be achievable.

Are you trying with ADAM or K-FAC? For any porting work, I would suggest first getting ferminet working with ADAM and then investigating K-FAC. I added support for MirroredStrategy to K-FAC but don't know whether MultiWorkerMirroredStrategy will work out of the box. I certainly hit many problems with tensor placement and naming when getting MirroredStrategy to work.
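
As a rough illustration (not the actual ferminet code; a toy loss stands in for the local-energy loss), an ADAM-only run under a distribution strategy in TF 1.15 looks something like this:

```python
import tensorflow as tf  # TF 1.15, graph mode

strategy = tf.distribute.MirroredStrategy()  # swap for the multi-worker variant once this works

with strategy.scope():
    w = tf.get_variable("w", initializer=1.0)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)

    def step_fn():
        loss = tf.square(w - 2.0)  # toy stand-in for the FermiNet loss
        return optimizer.minimize(loss)

    # Build one train op per replica; gradients are reduced across replicas.
    per_replica_op = strategy.experimental_run_v2(step_fn)
    train_ops = strategy.experimental_local_results(per_replica_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_ops)
```

Once something like that trains, swapping ADAM for K-FAC isolates any remaining placement problems to the optimizer itself.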

Re: JAX multi-node support, please see google/jax#2731. We're not intending to continue developing the TF version of ferminet.

Got it, thanks a lot for the info and suggestions @jsspencer! Let me try harder then.

I was indeed trying K-FAC, but it seems the error was thrown well before optimization started.

From the GitHub thread you mentioned, it seems JAX's multi-node support is still under development. Let me ask there whether they have anything to share now.

Thanks again!

These kinds of errors are normally triggered at graph-construction time rather than (later) at training time.

> These kinds of errors are normally triggered at graph-construction time rather than (later) at training time.

You are absolutely right. I've managed to fix this issue and successfully run the TF code with MultiWorkerMirroredStrategy on multiple nodes (modulo the fact that using K-FAC causes a core dump).