shrutirij/ocr-post-correction

Distributed training and auto-batching with this codebase


Hi @shrutirij, thanks for the great work and for the clean, well-documented codebase. I greatly appreciate it!

I adapted this codebase for some conceptually similar experiments and observed a few quirks. Since I haven't really used DyNet before, I wanted to run them by you and get your thoughts.

  1. I ran the code with --dynet-gpus 8 (after modifying opts.py to accept this argument). Although 8 processes are spawned and attached to 8 GPUs, only the first process shows more than 0% GPU utilization. It appears that the codebase doesn't support distributed training in its current form. Is that accurate? Does DyNet have an equivalent of PyTorch's DistributedDataParallel and DistributedSampler that I could use for data-parallel training and inference? That would greatly speed up my experiments.
  2. Training time on a CPU appears to be the same as on a single GPU enabled with the --dynet-gpu flag. Did you notice this too during your runs? If not, could you suggest how I can make it run faster on a GPU?
  3. DyNet's auto-batching doesn't seem to take effect: I ran the code with and without the --dynet-autobatch 1 flag and the run time didn't change. The main training loop looks like the following (where minibatch_size is always set to 1):
            for i in range(0, len(train_data), minibatch_size):
                cur_size = min(minibatch_size, len(train_data) - i)
                losses = []
                dy.renew_cg()
                for (src1, src2, tgt) in train_data[i : i + cur_size]:
                    losses.append(self.model.get_loss(src1, src2, tgt))
                batch_loss = dy.esum(losses)
                batch_loss.backward()
                trainer.update()
                epoch_loss += batch_loss.scalar_value()
            logging.info("Epoch loss: %0.4f" % (epoch_loss / len(train_data)))

Doesn't this mean that cur_size is always 1, so the inner for loop just iterates over a single-element list? If I were to override minibatch_size to, say, 32, how does DyNet ensure that one forward operation runs per batch of 32 examples instead of 32 separate forward passes?
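To make the question concrete, here is a minimal sketch of just the slicing arithmetic from the loop above, with no DyNet involved. The helper name minibatch_slices is hypothetical (it is not part of the repository); it only reproduces how i and cur_size evolve for a given dataset size.

```python
def minibatch_slices(num_examples, minibatch_size):
    """Return the (start, size) pairs that the training loop's slicing produces.

    Hypothetical helper for illustration only; it mirrors the arithmetic of
    `for i in range(0, len(train_data), minibatch_size)` and `cur_size`.
    """
    slices = []
    for i in range(0, num_examples, minibatch_size):
        cur_size = min(minibatch_size, num_examples - i)
        slices.append((i, cur_size))
    return slices


# With minibatch_size=1, every "batch" holds exactly one example.
print(minibatch_slices(5, 1))    # [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]

# With a larger size, each computation graph would hold up to 32 losses,
# and the final batch is smaller when the sizes do not divide evenly.
print(minibatch_slices(70, 32))  # [(0, 32), (32, 32), (64, 6)]
```

So with the default setting, dy.esum sums a single loss per graph. If I understand the DyNet docs correctly, auto-batching can only group operations that sit in the same computation graph, which may be why I see no cross-example speedup with the flag.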

Thanks a lot for your time and thanks again for the great work toward protecting endangered languages!