flax.linen.Conv unexpected behavior
NITHISHM2410 opened this issue · 1 comments
NITHISHM2410 commented
I'm getting unexpected output from flax.linen.Conv. The output of the conv layer has very odd statistics: the mean is around 100-110 and is sometimes NaN. I tested the same architecture against TensorFlow 2.15 and get the expected output there.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Kaggle Notebook, Ubuntu 20.04
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.4, jax 0.4.26, jaxlib 0.4.26
- Python version: 3.10.13
- GPU/TPU model and memory: Kaggle GPU P100 (16 GB)
- CUDA version (if applicable): 12.4
Problem you have encountered:
Here's the code to reproduce:
import jax
import jax.numpy as jnp
import jax.random as jr
import flax.linen as nn


class ResBlock(nn.Module):
    c_in: int
    c_out: int

    @nn.compact
    def __call__(self, h):
        hr = h  # residual branch
        h = nn.Conv(features=self.c_out, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(h)
        return h + hr


res = nn.Sequential([ResBlock(32, 32)])
p = res.init(jr.PRNGKey(12), jnp.ones((20, 32, 32, 32)))

@jax.jit
def apply_f(p, ip):
    return res.apply(p, ip)

x = apply_f(p, jnp.ones((40, 32, 32, 32)))
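As a quick cross-check (a sketch, not part of the original report), the same forward pass can be run on the CPU backend and compared against the GPU result; a large gap points at the GPU convolution path (XLA/cuDNN) rather than at Flax itself:

import numpy as np

# Run the same forward pass on CPU and compare with the GPU output.
cpu = jax.devices("cpu")[0]
p_cpu = jax.device_put(p, cpu)
x_cpu = res.apply(p_cpu, jax.device_put(jnp.ones((40, 32, 32, 32)), cpu))
print(np.abs(np.asarray(x) - np.asarray(x_cpu)).max())  # should be ~0 if the GPU conv is correct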
The first time I execute the above code I get the following message, but it doesn't stop execution:
2024-08-01 04:27:56.893767: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 0: 32.9203, expected 190.347
2024-08-01 04:27:56.893817: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 1: 49.2585, expected 288.954
2024-08-01 04:27:56.893827: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 2: 47.2793, expected 278.698
2024-08-01 04:27:56.893836: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 3: 47.7769, expected 282.143
2024-08-01 04:27:56.893844: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 4: 44.5133, expected 260.806
2024-08-01 04:27:56.893853: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 5: 46.5634, expected 272.129
2024-08-01 04:27:56.893862: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 6: 51.559, expected 302.015
2024-08-01 04:27:56.893871: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 7: 52.788, expected 311.954
2024-08-01 04:27:56.893879: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 8: 50.759, expected 299.335
2024-08-01 04:27:56.893888: E external/xla/xla/service/gpu/buffer_comparator.cc:151] Difference at 9: 47.583, expected 277.351
2024-08-01 04:27:56.894471: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:747] Results mismatch between different convolution algorithms. This is likely a bug/unexpected loss of precision in cudnn.
(f32[40,32,32,32]{3,2,1,0}, u8[0]{0}) custom-call(f32[40,32,32,32]{3,2,1,0}, f32[32,32,3,3]{3,2,1,0}, f32[32]{0}, f32[40,32,32,32]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":1,"leakyrelu_alpha":0},"force_earliest_schedule":false} for eng15{k5=1,k6=0,k7=1,k10=3} vs eng15{k5=1,k6=0,k7=1,k10=1}
2024-08-01 04:27:56.894496: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:307] Device: Tesla P100-PCIE-16GB
2024-08-01 04:27:56.894504: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:308] Platform: Compute Capability 6.0
2024-08-01 04:27:56.894510: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:309] Driver: 12040 (550.90.7)
2024-08-01 04:27:56.894517: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:310] Runtime: <undefined>
2024-08-01 04:27:56.894529: E external/xla/xla/service/gpu/conv_algorithm_picker.cc:317] cudnn version: 8.9.0
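The mismatch above is reported by XLA's convolution algorithm autotuner. As a sketch (the flag name comes from XLA and may differ across versions), one way to check whether the autotuner's engine choice is involved is to lower the GPU autotune level before jax is imported:

import os

# Must be set before jax/XLA is initialized; autotune level 0 skips the
# cuDNN engine benchmarking that produced the mismatch above.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax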
Logs, error messages, etc:
When I check the output stats, I find:
x = apply_f(p, jnp.ones((40, 32, 32, 32)))
x.mean(), x.max(), x.min(), x.std()
(Array(nan, dtype=float32),
Array(nan, dtype=float32),
Array(nan, dtype=float32),
Array(nan, dtype=float32))
Re-initializing the model and running it again several times results in:
x = apply_f(p, jnp.ones((40, 32, 32, 32)))
x.mean(), x.max(), x.min(), x.std()
(Array(-4.959361, dtype=float32),
Array(78.48207, dtype=float32),
Array(-84.155266, dtype=float32),
Array(39.148796, dtype=float32))
What you expected to happen:
I built the same model architecture in TensorFlow 2.15, and when I check the stats there:
x.mean(), x.max(), x.min(), x.std()
(0.68330705, 3.5983667, -1.6166728, 1.0021479)
The above is what I expect.
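For a rough sense of scale (a sketch, assuming Flax's default lecun_normal kernel init and zero bias): with an all-ones input, each interior output element is the residual 1 plus a sum of 3*3*32 = 288 kernel weights with variance 1/288, so the mean and std should both be around 1, nowhere near 100+ or NaN:

import numpy as np

fan_in = 3 * 3 * 32                       # kernel_h * kernel_w * c_in
w = np.random.normal(0.0, np.sqrt(1.0 / fan_in), size=(10000, fan_in))
sim = 1.0 + w.sum(axis=1)                 # simulated interior output elements
print(sim.mean(), sim.std())              # roughly 1.0 and 1.0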
Steps to reproduce:
Click here to access the Kaggle Notebook that reproduces the issue.
NITHISHM2410 commented
Reinstalling jax solved the issue.