iteration fails with distributed dataset
bryorsnef opened this issue · 1 comment
Hi,
I was hoping to implement a distributed custom training loop on two GPUs, but iterating over a distributed dataset with a for loop fails.
MWE:
library(tensorflow)
library(tfautograph)
strategy <- tf$distribute$MirroredStrategy()
dataset <- tf$data$Dataset$from_tensors(1.)$"repeat"(100L)$batch(16L)
dist_dataset <- strategy$experimental_distribute_dataset(dataset)
train_step <- tf_function(function(x) {
x - 1.
})
autograph({
  for (d in dist_dataset) {
    res <- strategy$run(train_step, args = list(d))
    tf$print(res)
  }
})
Error in .Primitive("for")(d, , { :
invalid for() loop sequence
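In case it's useful to others hitting this, manually driving the iterator with reticulate looks like a possible workaround; a sketch (as_iterator() and iter_next() are reticulate functions, and I haven't verified this on multiple GPUs):

library(reticulate)

# Drive the distributed dataset's iterator by hand instead of relying
# on for() dispatch; iter_next() returns NULL once the iterator is exhausted.
it <- as_iterator(dist_dataset)
while (!is.null(d <- iter_next(it))) {
  res <- strategy$run(train_step, args = list(d))
  tf$print(res)
}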
Hi, thank you very much for the excellent bug report. I was able to reproduce the issue and have pushed a fix.
Please install the dev version with remotes::install_github("t-kalinowski/tfautograph") and try again.
Unfortunately, I don't have easy access to a machine with multiple devices, so I can't test whether strategy$run works correctly in that case; on my machine with only one GPU device it results in a crashed R session. Can you please file a new issue if strategy$run also crashes for you when multiple devices are available?
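If you want to experiment without extra hardware, one option that may help (untested in this exact setup) is splitting a single physical device into multiple logical devices so MirroredStrategy has two replicas to mirror across; a sketch, assuming TF 2.x:

library(tensorflow)

# Split the CPU into two logical devices (this must run before any
# other TF op touches the device).
cpus <- tf$config$list_physical_devices("CPU")
tf$config$set_logical_device_configuration(
  cpus[[1]],
  list(
    tf$config$LogicalDeviceConfiguration(),
    tf$config$LogicalDeviceConfiguration()
  )
)
strategy <- tf$distribute$MirroredStrategy(devices = c("/cpu:0", "/cpu:1"))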