google/flax

flax.linen.Conv unexpected behavior.

NITHISHM2410 opened this issue · 1 comment

I'm experiencing unexpected output when using flax.linen.Conv. The output of the conv layer has very odd statistics: the mean is around 100–110 and is sometimes NaN. I tested the same model against TensorFlow 2.15 and got the expected output there.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Kaggle Notebook Ubuntu 20.04
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.4, jax 0.4.26, jaxlib 0.4.26
  • Python version: 3.10.13
  • GPU/TPU model and memory: Kaggle GPU P100 (16 GB)
  • CUDA version (if applicable): 12.4

Problem you have encountered:

Here's the code to reproduce:

import jax
import jax.numpy as jnp
import jax.random as jr
import flax.linen as nn


class ResBlock(nn.Module):
    c_in: int   # unused in this minimal repro
    c_out: int

    @nn.compact
    def __call__(self, h):
        hr = h  # residual branch
        h = nn.Conv(features=self.c_out, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(h)
        return h + hr


res = nn.Sequential([ResBlock(32, 32)])
p = res.init(jr.PRNGKey(12), jnp.ones((20, 32, 32, 32)))


@jax.jit
def apply_f(p, ip):
    return res.apply(p, ip)


x = apply_f(p, jnp.ones((40, 32, 32, 32)))
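
As an extra sanity check (a sketch only, not something included in the notebook above), the same params and input could be run through the model on CPU and compared against the GPU output, to see whether the odd values come from the GPU convolution path:

import numpy as np

# Sketch of a CPU-vs-GPU comparison (an assumption, not part of the original
# run). Reuses `res`, `p`, and `apply_f` from the snippet above.
cpu = jax.devices("cpu")[0]
p_cpu = jax.device_put(p, cpu)
ip_cpu = jax.device_put(jnp.ones((40, 32, 32, 32)), cpu)

x_gpu = np.asarray(apply_f(p, jnp.ones((40, 32, 32, 32))))  # runs on GPU
x_cpu = np.asarray(res.apply(p_cpu, ip_cpu))                # runs on CPU

print("max abs diff (GPU vs CPU):", np.abs(x_gpu - x_cpu).max())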

The first time I execute the above code I get the following message, though it doesn't stop execution:

2024-08-01 04:27:56.893767: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 0: 32.9203, expected 190.347
2024-08-01 04:27:56.893817: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 1: 49.2585, expected 288.954
2024-08-01 04:27:56.893827: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 2: 47.2793, expected 278.698
2024-08-01 04:27:56.893836: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 3: 47.7769, expected 282.143
2024-08-01 04:27:56.893844: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 4: 44.5133, expected 260.806
2024-08-01 04:27:56.893853: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 5: 46.5634, expected 272.129
2024-08-01 04:27:56.893862: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 6: 51.559, expected 302.015
2024-08-01 04:27:56.893871: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 7: 52.788, expected 311.954
2024-08-01 04:27:56.893879: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 8: 50.759, expected 299.335
2024-08-01 04:27:56.893888: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 9: 47.583, expected 277.351
2024-08-01 04:27:56.894471: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:747] Results mismatch between different convolution algorithms. This is likely a bug/unexpected loss of precision in cudnn.
(f32[40,32,32,32]{3,2,1,0}, u8[0]{0}) custom-call(f32[40,32,32,32]{3,2,1,0}, f32[32,32,3,3]{3,2,1,0}, f32[32]{0}, f32[40,32,32,32]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":1,"leakyrelu_alpha":0},"force_earliest_schedule":false} for eng15{k5=1,k6=0,k7=1,k10=3} vs eng15{k5=1,k6=0,k7=1,k10=1}
2024-08-01 04:27:56.894496: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:307] Device: Tesla P100-PCIE-16GB
2024-08-01 04:27:56.894504: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:308] Platform: Compute Capability 6.0
2024-08-01 04:27:56.894510: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:309] Driver: 12040 (550.90.7)
2024-08-01 04:27:56.894517: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:310] Runtime: <undefined>
2024-08-01 04:27:56.894529: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:317] cudnn version: 8.9.0
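
(A possible mitigation sketch, not something I've verified: since the message comes from XLA's conv_algorithm_picker, lowering the XLA GPU autotune level before JAX is imported should stop XLA from picking between cuDNN conv algorithms; whether that actually avoids the bad results is an assumption.)

# Possible mitigation sketch (assumption, not verified): disable XLA's GPU
# autotuning so it does not choose between cuDNN conv algorithms that disagree.
# XLA_FLAGS must be set before jax is imported (e.g. at the top of the
# notebook), otherwise it has no effect.
import os
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # import JAX only after XLA_FLAGS is set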

Logs, error messages, etc:

When I check the output stats, I see the following:

x = apply_f(p, jnp.ones((40, 32, 32, 32)))
x.mean(), x.max(), x.min(), x.std()
(Array(nan, dtype=float32),
 Array(nan, dtype=float32),
 Array(nan, dtype=float32),
 Array(nan, dtype=float32))

Re-initializing the model and running it again several times results in this:

x = apply_f(p, jnp.ones((40, 32, 32, 32)))
x.mean(), x.max(), x.min(), x.std()
(Array(-4.959361, dtype=float32),
 Array(78.48207, dtype=float32),
 Array(-84.155266, dtype=float32),
 Array(39.148796, dtype=float32))
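
To localize where the NaNs first appear (a general JAX debugging aid rather than anything specific to this bug), the NaN checker could be enabled so the first NaN-producing operation raises instead of propagating silently:

# General NaN-hunting sketch (not part of the original run): with
# jax_debug_nans enabled, a jitted function that produces a NaN is re-run
# op-by-op and a FloatingPointError is raised at the offending primitive.
# This slows execution down, so it is only for debugging.
jax.config.update("jax_debug_nans", True)

x = apply_f(p, jnp.ones((40, 32, 32, 32)))  # raises at the first NaN, if any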

What you expected to happen:

I built the same model architecture in TensorFlow, and checking the stats there gives:

x.mean(), x.max(), x.min(), x.std()
(0.68330705, 3.5983667, -1.6166728, 1.0021479)

The above is what I expect.
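
For reference, the TensorFlow comparison looked roughly like this (a minimal sketch; the exact Keras layer arguments here are assumed rather than copied from my TF code):

import tensorflow as tf

# Rough sketch of the TensorFlow comparison (assumed details): the same
# 3x3 'same'-padded conv with a residual add, applied to an all-ones input.
class TFResBlock(tf.keras.layers.Layer):
    def __init__(self, c_out):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(c_out, kernel_size=3, strides=1, padding="same")

    def call(self, h):
        return self.conv(h) + h

model = TFResBlock(32)
x = model(tf.ones((40, 32, 32, 32)))
print(float(tf.reduce_mean(x)), float(tf.reduce_max(x)),
      float(tf.reduce_min(x)), float(tf.math.reduce_std(x)))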

Steps to reproduce:

Click here to access the Kaggle Notebook for reproducing the issue

Reinstalling jax solved the issue.