
iteration fails with distributed dataset

bryorsnef opened this issue · 1 comments


I was hoping to implement a distributed custom training loop on two GPUs. However, iterating over a distributed dataset with a for loop fails with the error shown below the code.



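library(tensorflow)

strategy <- tf$distribute$MirroredStrategy()

dataset <- tf$data$Dataset$from_tensors(1.)$"repeat"(100L)$batch(16L)
dist_dataset <- strategy$experimental_distribute_dataset(dataset)

# Per-replica step: subtract 1 from each element of the batch.
train_step <- tf_function(function(x) {
  x - 1.
})

# Iterating over the distributed dataset is the step that fails.
for (d in dist_dataset) {
  res <- strategy$run(train_step, args = list(d))
}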

Error in .Primitive("for")(d, , { :
invalid for() loop sequence

Hi, thank you very much for the excellent bug report. I was able to reproduce the issue and have pushed a fix.

Please install the dev version with remotes::install_github("t-kalinowski/tfautograph") and try again.

Unfortunately, I don't have easy access to a machine with multiple devices, so I can't test whether strategy$run works correctly in that case. On my machine, which has only one GPU, calling it crashes the R session. If strategy$run also crashes for you when multiple devices are available, could you please file a new issue?
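
If it helps when writing up a new issue, here is a minimal sketch for recording how many devices TensorFlow actually sees in the session. It is not part of the fix and is only an assumption about what would be useful to report. It assumes a TensorFlow 2.x install where tf$config$list_physical_devices is available (older releases expose it under tf$config$experimental).

```r
library(tensorflow)

# GPUs visible to TensorFlow in this R session.
gpus <- tf$config$list_physical_devices("GPU")
cat("Visible GPUs:", length(gpus), "\n")

# By default MirroredStrategy creates one replica per visible GPU
# (or a single CPU replica when no GPU is found).
strategy <- tf$distribute$MirroredStrategy()
cat("Replicas in sync:", strategy$num_replicas_in_sync, "\n")

# Version details worth including in the report.
cat("TensorFlow:", tf$version$VERSION, "\n")
cat("tfautograph:", as.character(packageVersion("tfautograph")), "\n")
```

A replica count of 1 would mean strategy$run is being exercised in the same single-device setting as on my machine, even if the hardware has more GPUs installed.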