keras-team/keras

Batchnorm in shared layers goes to nan

sakvaua opened this issue · 40 comments

Hi,
There seems to be a bug in the Batch Normalization layer when it is used in shared layers. I traced the problem to the running mean growing uncontrollably and then going to NaN. It surfaced in a triplet-loss-style model of mine. See the attached code to reproduce. Several other issues report this problem:
#11897
#9646

import numpy as np

import keras
import keras.backend as K

from keras.optimizers import Adam
from keras.applications import ResNet50
from keras.applications.resnet50 import preprocess_input as preprocess_resnet50
from keras.models import Model

from keras.layers import (Input, Lambda, subtract, concatenate, Activation, Dense, Dropout,
                          Flatten, Conv2D, MaxPooling2D, SeparableConv2D, DepthwiseConv2D,
                          BatchNormalization, GlobalAveragePooling2D, GlobalMaxPooling2D)

def get_keras_model(model, input_shape):
    x = model.output
    x = GlobalMaxPooling2D()(x)
    x = BatchNormalization()(x)
    x = Dense(512, kernel_initializer='he_normal')(x)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)
    o = Dense(5005, kernel_initializer='he_normal', activation='softmax')(x)
    return Model(inputs=model.input, outputs=o)

def get_resnet50(input_shape=(192, 256, 3)):
    K.clear_session()
    resnet = ResNet50(include_top=False, weights='imagenet', input_tensor=None,
                      input_shape=input_shape, pooling=None)

    # keep only the BatchNormalization layers trainable, freeze everything else
    for layer in resnet.layers:
        if isinstance(layer, BatchNormalization):
            layer.trainable = True
        else:
            layer.trainable = False
    return resnet


resnet = ResNet50(include_top=False, weights='imagenet', input_tensor=None,
                  input_shape=(192, 256, 3), pooling=None)
model = get_keras_model(get_resnet50(input_shape=(192, 256, 3)), input_shape=(192, 256, 3))
# model.load_weights(path+'models/Resnet50/loss_{:.4f}_{:.4f}_{:02d}.hdf5'.format(4.9773, 0.6173, 1))
trimmed_resnet50 = Model(inputs=model.input, outputs=model.layers[173].output)
x = trimmed_resnet50.output
x = GlobalAveragePooling2D()(x)
x = BatchNormalization()(x)
x = Dense(1024, kernel_initializer='he_normal')(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
o = Dense(128, kernel_initializer='he_normal')(x)
base_net = Model(inputs=model.input, outputs=o)
# base_net.summary()


def create_model(base_model, input_shape=(192, 256, 3)):
    input_tensor1 = Input(shape=input_shape)
    input_tensor2 = Input(shape=input_shape)
    input_tensor3 = Input(shape=input_shape)

    reduce_sum = Dense(1, activation='linear', kernel_initializer='ones',
                       bias_initializer='zeros', name='reduce_sum')

    x1 = base_model(input_tensor1)
    x2 = base_model(input_tensor2)
    x3 = base_model(input_tensor3)

    d12 = subtract([x1, x2])
    d13 = subtract([x1, x3])

    d12 = Lambda(lambda val: val ** 2)(d12)
    d13 = Lambda(lambda val: val ** 2)(d13)

    d12 = reduce_sum(d12)
    d13 = reduce_sum(d13)

    d12 = Lambda(lambda val: K.sqrt(val + K.epsilon()))(d12)
    d13 = Lambda(lambda val: K.sqrt(val + K.epsilon()))(d13)

    d = concatenate([d12, d13])
    d = Activation('softmax')(d)

    model = Model(inputs=[input_tensor1, input_tensor2, input_tensor3], outputs=d)
    metric = Model(inputs=input_tensor1, outputs=x1)

    for l in model.layers:
        if l.name == 'reduce_sum':
            print('reduce sum')
            l.trainable = False
    return model, metric
    
triplet_model, metric = create_model(base_net)

train_1 = np.random.uniform(-128, 128, size=(1000, 192, 256, 3))
train_2 = np.random.uniform(-128, 128, size=(1000, 192, 256, 3))
train_3 = np.random.uniform(-128, 128, size=(1000, 192, 256, 3))
train_y = np.array([1, 0] * 1000).reshape(1000, 2)

triplet_model.compile(loss='categorical_crossentropy', optimizer=Adam(1e-5, clipnorm=1.0), metrics=['accuracy'])
triplet_model.fit([train_1, train_2, train_3], train_y, batch_size=16, epochs=5)

intermediate = Model(inputs=metric.layers[1].get_input_at(0), outputs=metric.layers[1].layers[3].output)
output = intermediate.predict(train_1[0:1])
print(metric.layers[1].layers[3].name)
print(output.shape, output)

# for a BatchNormalization layer, weights[2] is its moving_mean
print(intermediate.layers[3].weights[2])
print(K.eval(intermediate.layers[3].weights[2]))

I'm trying to train a triplet-loss-style CNN that shares a single ResNet50 model across three inputs.
Please make sure that the boxes below are checked before you submit your issue.
If your issue is an implementation question, please ask your question on StackOverflow or on the Keras Slack channel instead of opening a GitHub issue.

Thank you!

  • [x] Check that you are up-to-date with the master branch of Keras. You can update with:
    pip install git+git://github.com/keras-team/keras.git --upgrade --no-deps

  • [x] Check that your version of TensorFlow is up-to-date. The installation instructions can be found here.

  • [x] Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).

I hope this bug gets resolved soon. I spent the past 4 weeks trying to figure out the issue before realizing that the Batch Norm layers were the culprit.

My training process takes days, and only then did I realize that all the while something was wrong with the implementation.

Sorry for bumping the topic. Any luck tracking it down? Is there maybe a workaround?

Did you solve the problem? I've run into the same one.

msymp commented

Hi @fchollet, several users in this thread have hit this Batch Normalization bug when using it in shared layers. Thanks.

I just came across this same bug while training a triplet loss model.

I'm using ResNet50 as the encoder and a standard triplet loss:

    input_anchor = Input(shape=(224, 224, 3))
    input_positive = Input(shape=(224, 224, 3))
    input_negative = Input(shape=(224, 224, 3))

    encoded_anchor = resnet_encoder(input_anchor)
    encoded_positive = resnet_encoder(input_positive)
    encoded_negative = resnet_encoder(input_negative)

    distance_good = Lambda(euclidean_distance)([encoded_anchor, encoded_positive])
    distance_bad = Lambda(euclidean_distance)([encoded_anchor, encoded_negative])
    distances = Concatenate(axis=1)([distance_good, distance_bad])

    model = Model([input_anchor, input_positive, input_negative], distances)

With the triplet loss defined as:

def triplet_loss(y_true, y_pred):
    '''Triplet loss'''
    distance_good = y_pred[:, 0]
    distance_bad = y_pred[:, 1]
    margin = 0.2
    loss = K.square(distance_good) - K.square(distance_bad) + margin
    return K.maximum(loss, 0.0)

def euclidean_distance(vects):
    '''Return the euclidean distance between two vectors'''
    x, y = vects
    distance = K.sqrt(K.sum(K.square(x - y), axis=1, keepdims=True))
    return K.maximum(distance, K.epsilon())

The training loss decreases nicely while the testing loss explodes after ~30 batches. The culprit is the moving_average of the first batch normalization layer.

I tried training with my own dataset and with random noise images, with and without ImageNet weights; the result is always the same. The moving_mean starts with a maximum value of around 3 and doubles each epoch until it breaks the entire model.

Ironically, training doesn't break, because the "learnt" broken moving_mean is only used in the testing phase, while the local minibatch mean is used for BatchNormalization during training, so the bug is very hard to spot!
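
One way to make this easier to spot (an illustrative sketch, not part of the original comment; BNMonitor and iter_bn_layers are made-up names): log the largest absolute moving_mean of every BatchNormalization layer at the end of each epoch, descending into nested sub-models, so an exploding moving average shows up during training rather than only at test time.

import numpy as np
import keras
import keras.backend as K

def iter_bn_layers(model):
    # yield every BatchNormalization layer, descending into nested sub-models
    for layer in model.layers:
        if isinstance(layer, keras.layers.BatchNormalization):
            yield layer
        elif isinstance(layer, keras.models.Model):
            for inner in iter_bn_layers(layer):
                yield inner

class BNMonitor(keras.callbacks.Callback):
    # print the largest absolute moving_mean across all BN layers after each epoch
    def on_epoch_end(self, epoch, logs=None):
        peaks = [np.abs(K.get_value(l.moving_mean)).max() for l in iter_bn_layers(self.model)]
        if peaks:
            print('epoch {}: max |moving_mean| = {:.3f}'.format(epoch, max(peaks)))

# usage: model.fit(..., callbacks=[BNMonitor()])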

@fchollet is there anything that can be done about this problem?

It has been close to 2 months now, to no avail. It would be helpful if this issue/bug were documented in the Keras documentation, since we usually realize very late that something is wrong with the library rather than with our code.

In the past I spent close to a month trying to fix my code before realizing that it was not the architecture but the running means that were going wrong.

I believe this issue is also related to this tracked bug in tensorflow, although it was closed 2 weeks ago because of lack of activity...

@sidgairo18 do you have any workaround?

@ale152 Hi! I had an upcoming paper deadline, so my workaround was porting the entire code to PyTorch. (I was really frustrated trying to fix this, with absolutely no support whatsoever.)

It worked like a breeze.

Oh no, it would be an immense amount of work for me to port everything to PyTorch...

Anyway, I can confirm that the bug only affects networks with 3 inputs and does not affect networks with 2 inputs:

import numpy as np

from keras.models import Model
from keras.layers import Input, Flatten, Lambda
from keras.optimizers import Adam
from keras import backend as K
from keras.applications import ResNet50


def dummy_distance(vects):
    x, y, z = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

# network definition
input_shape = (224, 224, 3)
model_resnet = ResNet50(include_top=False)

input_encoder = Input(shape=input_shape)
net = model_resnet(input_encoder)
output_encoder = Flatten()(net)
resnet_encoder = Model(input_encoder, output_encoder)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
input_c = Input(shape=input_shape)

processed_a = resnet_encoder(input_a)
processed_b = resnet_encoder(input_b)
processed_c = resnet_encoder(input_c)

distance = Lambda(dummy_distance)([processed_a, processed_b, processed_c])

model = Model([input_a, input_b, input_c], distance)

# train
model.compile(loss='mse', optimizer=Adam())

for i in range(30):
    print('Batch {}'.format(i))
    input_1 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))
    input_2 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))
    input_3 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))

    y = np.zeros((2,1))

    out = model.train_on_batch([input_1, input_2, input_3], y)

    submodel = model.layers[3].layers[1]
    first_bn = submodel.layers[2]
    print('Maximum bn moving_mean: {}'.format(K.get_value(first_bn.moving_mean).max()))

produces:

Batch 0
Maximum bn moving_mean: 9.986820220947266
Batch 1
Maximum bn moving_mean: 15.314239501953125
Batch 2
Maximum bn moving_mean: 34.565223693847656
Batch 3
Maximum bn moving_mean: 68.03244018554688
Batch 4
Maximum bn moving_mean: 132.89358520507812
Batch 5
Maximum bn moving_mean: 278.90631103515625
Batch 6
Maximum bn moving_mean: 526.226318359375
Batch 7
Maximum bn moving_mean: 1122.40673828125
Batch 8
Maximum bn moving_mean: 2099.57470703125
Batch 9
Maximum bn moving_mean: 4496.4072265625
Batch 10
Maximum bn moving_mean: 8392.982421875
Batch 11
Maximum bn moving_mean: 17992.404296875
Batch 12
Maximum bn moving_mean: 33566.62109375
Batch 13
Maximum bn moving_mean: 71976.390625
Batch 14
Maximum bn moving_mean: 134261.1875
Batch 15
Maximum bn moving_mean: 287912.375
Batch 16
Maximum bn moving_mean: 537039.375
Batch 17
Maximum bn moving_mean: 1151656.25
Batch 18
Maximum bn moving_mean: 2148152.25
Batch 19
Maximum bn moving_mean: 4606632.0

With two inputs only:

import numpy as np

from keras.models import Model
from keras.layers import Input, Flatten, Lambda
from keras.optimizers import Adam
from keras import backend as K
from keras.applications import ResNet50


def dummy_distance(vects):
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

# network definition
input_shape = (224, 224, 3)
model_resnet = ResNet50(include_top=False)

input_encoder = Input(shape=input_shape)
net = model_resnet(input_encoder)
output_encoder = Flatten()(net)
resnet_encoder = Model(input_encoder, output_encoder)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

processed_a = resnet_encoder(input_a)
processed_b = resnet_encoder(input_b)

distance = Lambda(dummy_distance)([processed_a, processed_b])

model = Model([input_a, input_b], distance)

# train
model.compile(loss='mse', optimizer=Adam())

for i in range(30):
    print('Batch {}'.format(i))
    input_1 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))
    input_2 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))

    y = np.zeros((2,1))

    out = model.train_on_batch([input_1, input_2], y)

    submodel = model.layers[2].layers[1]
    first_bn = submodel.layers[2]
    print('Maximum bn moving_mean: {}'.format(K.get_value(first_bn.moving_mean).max()))

it works as expected:

Batch 0
Maximum bn moving_mean: 5.890077590942383
Batch 1
Maximum bn moving_mean: 3.331754684448242
Batch 2
Maximum bn moving_mean: 5.894614219665527
Batch 3
Maximum bn moving_mean: 3.330280065536499
Batch 4
Maximum bn moving_mean: 5.899479389190674
Batch 5
Maximum bn moving_mean: 3.3306570053100586
Batch 6
Maximum bn moving_mean: 5.904243469238281
Batch 7
Maximum bn moving_mean: 3.3300342559814453
Batch 8
Maximum bn moving_mean: 5.909705638885498
Batch 9
Maximum bn moving_mean: 3.3289315700531006

This explains why so many people experienced this bug while implementing a triplet loss...

Just for the record, with 3 inputs the moving_mean roughly doubles at every step, while with 4 inputs it roughly triples. I hope this will help to find the bug.
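
One way to see this multiplication directly (an illustrative check of my own, reusing resnet_encoder and model from the 3-input example above): in Keras 2.x, every call of the shared sub-model registers its own set of BatchNormalization update ops, and the outer model collects all of them, so the moving statistics are pushed several times per training step, which is consistent with the doubled/tripled growth.

# illustrative only; the exact counts depend on the Keras version
print('update ops owned by the shared encoder :', len(resnet_encoder.updates))
print('update ops collected by the outer model:', len(model.updates))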

@ale152 I understand it would be painful for you to rewrite the entire thing, but I strongly recommend that you do so; it should not take more than a couple of days. It would be a good switch at the right time.

I have been experiencing similar issues training a convolutional network that uses a triplet loss in Keras. It seems that the problem can be countered by explicitly re-importing TensorFlow and/or creating a new session right before creating and fitting your model.

I.e.

import tensorflow as tf
with tf.Session() as sess:
    # Create your model and run it afresh

I'm not sure if this will be a solution for everyone, but it may help in tracking down the root cause of the problem.

@sdpenguin I tried as you said, but the batch normalisation layer keeps exploding (see code below).
Can you get this simple example to work with your trick?

import numpy as np

from keras.models import Model
from keras.layers import Input, Flatten, Lambda
from keras.optimizers import Adam
from keras import backend as K
from keras.applications import ResNet50


def dummy_distance(vects):
    x, y, z = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

import tensorflow as tf
with tf.Session() as sess:
    # network definition
    input_shape = (224, 224, 3)
    model_resnet = ResNet50(include_top=False)

    input_encoder = Input(shape=input_shape)
    net = model_resnet(input_encoder)
    output_encoder = Flatten()(net)
    resnet_encoder = Model(input_encoder, output_encoder)

    input_a = Input(shape=input_shape)
    input_b = Input(shape=input_shape)
    input_c = Input(shape=input_shape)

    processed_a = resnet_encoder(input_a)
    processed_b = resnet_encoder(input_b)
    processed_c = resnet_encoder(input_c)

    distance = Lambda(dummy_distance)([processed_a, processed_b, processed_c])

    model = Model([input_a, input_b, input_c], distance)

    # train
    model.compile(loss='mse', optimizer=Adam())

    for i in range(30):
        print('Batch {}'.format(i))
        input_1 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))
        input_2 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))
        input_3 = np.random.random(([2, input_shape[0], input_shape[1], input_shape[2]]))

        y = np.zeros((2,1))

        out = model.train_on_batch([input_1, input_2, input_3], y)

        submodel = model.layers[3].layers[1]
        first_bn = submodel.layers[2]
        print('Maximum bn moving_mean: {}'.format(K.get_value(first_bn.moving_mean).max()))

@ale152 Hi, sorry for the late reply. I couldn't get your code to work with the re-import, and it seems this doesn't actually completely solve the problem for my code either. I have searched around and found one possibly relevant thread, though: #9965. Perhaps you can try initialising the batch normalisation layers in your models with:

x = BatchNormalization()(y, training=False)

As suggested in the thread.

@sdpenguin The problem with setting training=False is that the layer isn't trained anymore... At that point, it'd be the same as not using BatchNormalization at all.

I noticed that the instability with BatchNormalization also depends on the depth of the network. In the following example I'm using 8 layers of BatchNormalization and the layer explodes after ~500 epochs. With 12 layers, the network becomes unstable after just 10 epochs!

import keras
import numpy as np

from keras.models import Model, Sequential
from keras.layers import Input, Flatten, Lambda, Conv2D, MaxPooling2D, Dropout, Dense, BatchNormalization
from keras.optimizers import Adam
from keras import backend as K
from keras.applications import ResNet50


def dummy_distance(vects):
    x, y, z = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

# network definition
input_shape = (5, 5, 1)
inputs = Input(input_shape)
net = Conv2D(3, kernel_size=(3, 3), input_shape=(5, 5, 1))(inputs)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = BatchNormalization()(net)
net = Conv2D(4, (3, 3), activation='relu', padding='same')(net)
net = Flatten()(net)
net = Dense(3, activation='relu')(net)
encoder = Model(inputs=inputs, outputs=net)

input_a = Input(shape=(5, 5, 1))
input_b = Input(shape=(5, 5, 1))
input_c = Input(shape=(5, 5, 1))

processed_a = encoder(input_a)
processed_b = encoder(input_b)
processed_c = encoder(input_c)

distance = Lambda(dummy_distance)([processed_a, processed_b, processed_c])

model = Model([input_a, input_b, input_c], distance)

# train
model.compile(loss='mse', optimizer=Adam(lr=1e-6))
bf = []
bs = 128
for i in range(5000):
    print('Batch {}'.format(i))
    input_1 = np.random.random(([bs, input_shape[0], input_shape[1], input_shape[2]]))
    input_2 = np.random.random(([bs, input_shape[0], input_shape[1], input_shape[2]]))
    input_3 = np.random.random(([bs, input_shape[0], input_shape[1], input_shape[2]]))

    y = np.zeros((bs,1))

    out = model.train_on_batch([input_1, input_2, input_3], y)

    first_bn = model.layers[3].layers[3]
    bf.append(K.get_value(first_bn.moving_mean).max())
    print('Maximum bn moving_mean: {}'.format(bf[-1]))

This bug is now six months old and is still breaking Siamese networks. Any updates?

The suggestions so far of disabling BatchNorm layers or rewriting everything in PyTorch aren't ideal.

@kwaegel I updated my environment to python 3.7, keras 2.2.4, tensorflow 1.13.1 and I can't reproduce the bug anymore. Maybe it was fixed and never announced here?

I updated my environment to python 3.7, keras 2.2.4, tensorflow 1.13.1 and I can't reproduce the bug anymore. Maybe it was fixed and never announced here?

@ale152: did you run your training several times and always get a good result, or did you only try once?

In my case, I'm with:

  • Python 3.6.7 (default in Ubuntu 18.04)
  • Keras 2.2.4
  • Backend = Tensorflow 1.13.1 GPU

I also have a Siamese network with a triplet loss (so 3 inputs to the network).
The crazy thing in my case is:

  • For the first 2 training runs, I got a "correct" result. "Correct" means I don't get NaN in the moving average, but it is not really correct, because my moving average and moving variance are very high.
  • From then on, I tried a 3rd, 4th, ... time and always got NaN in the moving average of the BN layer (which is just after the first Conv2D). I cannot reproduce the "correct" result of the first 2 runs.
    My code and my database are the same for all of the training runs above. The only thing that differs is the result of the shuffle when I create the mini-batches.

Thanks!

@baocareos, do you also use Dropout in your network? I noticed that Dropout sometimes causes problems when used together with BatchNorm...

@ale152: No, I don't use Dropout in my network.
I mainly use Conv2D/DepthwiseConv2D, BatchNorm, ReLU, and just one Dense layer (without dropout) plus an L2 norm layer at the end of the base network.
The Siamese network uses this base network but with 3 inputs.


@baocareos I get the same error. It is really sad that over 7 months have passed and this still has not been fixed. I've had a lot of trouble because of this.

I have the same error when training a Siamese-like network with a binary loss, and I spent a couple of weeks figuring out that batch normalization was the culprit. My batch normalization layer was after the global average pooling of ResNet50. Batch normalization works by computing the mean and variance of the features of the samples in a batch:

x_hat = (x - batch_mean) / sqrt(batch_var)
batch_var = 1/(N-1) * sum((x - batch_mean)^2)

There are two cases where the BN layer can give NaN: (1) x = 0, mean = 0, var = 0, which can occur when a particular channel has no activation in any sample; (2) N = 1, which makes the variance estimate a division by zero and can result in NaN. Case 2 can easily be solved by changing the batch size, but for case 1 we cannot really control the output of a filter and hence the channel features. For now I have removed that layer. I would be grateful if this bug could be fixed!
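
A small NumPy illustration of the two cases above (an illustrative sketch, not part of the original comment; note that Keras's BatchNormalization adds an epsilon inside the square root, which is deliberately omitted here to expose the failure mode):

import numpy as np

def naive_batch_norm(x):
    # normalize each column by the batch mean and the unbiased batch variance,
    # without the usual epsilon guard
    mean = x.mean(axis=0)
    var = x.var(axis=0, ddof=1)   # 1/(N-1)
    return (x - mean) / np.sqrt(var)

# case 1: a channel that is identically zero -> var = 0 -> 0/0 -> nan
print(naive_batch_norm(np.zeros((4, 3))))

# case 2: batch size N = 1 -> the 1/(N-1) variance is itself 0/0 -> nan
print(naive_batch_norm(np.random.rand(1, 3)))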

I am facing the same issue training GANs with BatchNorm layers in the generator and 3 losses. The training loss appears to be in a reasonable range, but in evaluation both the moving mean and the moving variance are NaN.

I am facing the same issue with a ResNet50 from tf.keras.applications.resnet50.ResNet50 with only a batch of images as input. The moving mean and moving variance explode until they become NaN.
I am using TensorFlow 1.14 and Python 3.7.

@fchollet, or whoever is looking at or assigned to this issue: it should be prioritized, as it basically renders BatchNorm unusable in practical situations. BTW, Merry Christmas!

Is there any update on this?

I encountered the same problem training a pseudo-Siamese network with an MMD loss and an MAE loss. Is there any solution?

Wow, it is really sad to see this issue open for almost 2 years...

@PokeLu

I don't use this batch normalization myself, but did you try adding the norm of the moving mean to your loss? (e.g. + 0.1 * abs(moving_mean))

This is an ugly fix, but it may prevent the coefficient from exploding (see L2/L1 regularization).

Just commenting to add weight to this issue. Like others, I wasted a lot of time tracking this down. And there is no real workaround available. I'd love to see this solved!

I have experienced a similar issue and described it here:
tensorflow/tensorflow#43603

I have a feeling that batch normalization somehow wrongly aggregates values across the columns of the shared vision model...
I am hoping for a real solution.

A similar network can be rewritten as a set of TimeDistributed layers, or as a 3D convolution (in these 2 cases, the normalization seems not to crash); however, I cannot use them, as I need to extract my shared vision model and reuse it later on single images. Also, 3D convolutions and TimeDistributed wrappers are not convertible to TFLite.
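
For reference, a minimal sketch of the TimeDistributed rewrite mentioned above (an illustrative example with a made-up tiny encoder, not the poster's actual model): the three images are stacked along an extra axis and the shared encoder is applied through a single TimeDistributed call instead of three separate calls.

import numpy as np
from keras.models import Model
from keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense, TimeDistributed

# a tiny shared encoder (illustrative only)
img_in = Input(shape=(5, 5, 1))
net = Conv2D(4, (3, 3), padding='same', activation='relu')(img_in)
net = BatchNormalization()(net)
net = Flatten()(net)
net = Dense(3)(net)
encoder = Model(img_in, net)

# anchor, positive and negative stacked along a new axis of length 3
triplet_in = Input(shape=(3, 5, 5, 1))
embeddings = TimeDistributed(encoder)(triplet_in)   # output shape: (batch, 3, 3)
model = Model(triplet_in, embeddings)
model.compile(loss='mse', optimizer='adam')

x = np.random.random((8, 3, 5, 5, 1))
y = np.zeros((8, 3, 3))
model.fit(x, y, epochs=1, verbose=0)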

Has this been resolved?

Same problem here! I've already checked many things. Since we do not use fused batch norm, a batch size of 1 should be okay for us. We also do not use dropout. I really cannot find any clue!

I encountered the same problem here. My training goes fine, but when testing (i.e. when setting training=False in BatchNormalization), it produced NaN after a few batches.

Wow, I'm having the same problem, and nothing yet 🤦🏼‍♂️
I think I need to learn PyTorch.

In my case there was a bug in my custom loop: I had batches of size 0 and 1, which makes sense; the variance and the mean would be undefined in those cases.
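
For anyone with a similar custom loop, a minimal guard (illustrative only; x_train, y_train, batch_size and model are placeholder names) that skips batches too small for the batch statistics to be defined:

for start in range(0, len(x_train), batch_size):
    x_batch = x_train[start:start + batch_size]
    y_batch = y_train[start:start + batch_size]
    if len(x_batch) < 2:   # size-0 or size-1 batches: batch mean/variance are undefined
        continue
    model.train_on_batch(x_batch, y_batch)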

@fchollet Why did you close this issue - was there a commit associated with a fix? Otherwise it appears unresolved and closed by mistake.

I can confirm that there are still issues with BatchNorm when two paths go through the same BN layer (Siamese, multi-tower, etc.).

I'm surprised the issue was simply closed without any solution or suggestion of a workaround.

This issue happens for me training a GAN with 3 losses. Closing it without saying a word is a slap in the face to the people who spent weeks struggling to find out what their issue was. Time to check out PyTorch.

Does this have anything to do with tf.keras.layers.experimental.SyncBatchNormalization()?