
Mirrored strategy not working with generator

mr-francois opened this issue · 2 comments

Whenever I try to train a model on multiple GPUs with a mirrored strategy, training freezes at the first validation step.
If I don't use validation data, training freezes after the last epoch.


library(keras)  # provides layer_dense(), compile() and the %>% pipe used below

in_shape <- 100
out_shape <- 3
batch_size <- 16

# Returns a generator function that yields the same all-zero (x, y) batch on every call
generator_dummy <- function(batch_size, in_shape, out_shape) {
  x <- array(0, dim = c(batch_size, in_shape))
  y <- array(0, dim = c(batch_size, out_shape))
  function() {
    list(x, y)
  }
}

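mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()

# Build and compile the model inside the strategy scope so its variables are mirrored across GPUs
with(mirrored_strategy$scope(), {
  model <- keras::keras_model_sequential(input_shape = in_shape) %>%
    layer_dense(32) %>%
    layer_dense(out_shape)  # output width must match the 3-column targets
  model %>% compile(loss = "mse",
                    optimizer = "adam")
})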

gen_train <- generator_dummy(batch_size, in_shape, out_shape)
gen_val <- generator_dummy(batch_size, in_shape, out_shape)

history <-
  model %>% keras::fit(
    x = gen_train,
    validation_data = gen_val,
    steps_per_epoch = 10,
    validation_steps = 3,
    epochs = 4)

Equivalent code in Python runs without problems on the same machine and in the same conda environment.

These are my current settings:

> reticulate::py_config()
python:         ~/anaconda3/envs/env_tf_gpu/bin/python
libpython:      ~/anaconda3/envs/env_tf_gpu/lib/
pythonhome:     ~/anaconda3/envs/env_tf_gpu:~/anaconda3/envs/env_tf_gpu
version:        3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21)  [GCC 9.4.0]
numpy:          ~/anaconda3/envs/env_tf_gpu/lib/python3.7/site-packages/numpy
numpy_version:  1.21.6

NOTE: Python version was forced by RETICULATE_PYTHON_ENV
> tensorflow::tf_config()
TensorFlow v2.11.0 ()
Python v3.7 (~/anaconda3/envs/env_tf_gpu/bin/python)
> sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-conda-linux-gnu (64-bit)
Running under: AlmaLinux 8.5 (Arctic Sphynx)

Matrix products: default
BLAS/LAPACK: ~/anaconda3/envs/env_tf_gpu/lib/

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C
 [9] LC_ADDRESS=C               LC_TELEPHONE=C

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

loaded via a namespace (and not attached):
 [1] compiler_4.1.1    magrittr_2.0.3    Matrix_1.5-3      whisker_0.4.1
 [5] base64enc_0.1-3   withr_2.5.0       Rcpp_1.0.10       reticulate_1.28
 [9] tensorflow_2.11.0 grid_4.1.1        jsonlite_1.8.4    tfruns_1.5.1
[13] png_0.1-8         lattice_0.20-45

Hi @mr-francois.

It looks like there is a deadlock when tf$distribute$MirroredStrategy is used with an R generator that produces the training batches. Fixing this will require some investigation.

Note, however, that a user-defined R generator will usually be the bottleneck in the training pipeline, since every batch has to be produced by calling back into R. If you can define your input pipeline using {tfdatasets}, you should see much better performance.
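A minimal sketch of that suggestion, assuming the {tfdatasets} package is installed and the training data fits in memory as R arrays. The arrays `x` and `y` and the helper `make_dataset()` are hypothetical stand-ins for whatever the generator produced. Batches come from a TensorFlow dataset built with `tensor_slices_dataset()`, `dataset_batch()` and `dataset_repeat()` instead of an R callback, so `fit()` does not need to call back into R for data.

```r
library(tensorflow)
library(keras)
library(tfdatasets)

in_shape <- 100
out_shape <- 3
batch_size <- 16

# Hypothetical in-memory data standing in for the generator's output:
# 10 batches' worth of all-zero inputs and targets, like generator_dummy().
n <- batch_size * 10
x <- array(0, dim = c(n, in_shape))
y <- array(0, dim = c(n, out_shape))

# Hypothetical helper: slice the arrays into (x, y) rows, group them into
# batches, and repeat indefinitely so steps_per_epoch / validation_steps
# decide how many batches each pass consumes.
make_dataset <- function(x, y, batch_size) {
  tensor_slices_dataset(list(x, y)) %>%
    dataset_batch(batch_size) %>%
    dataset_repeat()
}

train_ds <- make_dataset(x, y, batch_size)
val_ds <- make_dataset(x, y, batch_size)

mirrored_strategy <- tf$distribute$MirroredStrategy()

# Model variables are created inside the strategy scope, as in the report
with(mirrored_strategy$scope(), {
  model <- keras_model_sequential() %>%
    layer_dense(units = 32, input_shape = in_shape) %>%
    layer_dense(units = out_shape)
  model %>% compile(loss = "mse", optimizer = "adam")
})

# Pass the datasets where the generators were passed before
history <- model %>% fit(
  train_ds,
  validation_data = val_ds,
  steps_per_epoch = 10,
  validation_steps = 3,
  epochs = 4
)
```

Because `dataset_repeat()` makes both datasets infinite, `steps_per_epoch` and `validation_steps` play the same role they did with the generators. For data that does not fit in memory, {tfdatasets} also provides file-based readers such as `text_line_dataset()` and `tfrecord_dataset()`.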

Fixed on main now.