google/flax

nnx.swish, jax.nn.swish, ... change the input shape

leson207 opened this issue · 2 comments

System information

Problem you have encountered:

print(xBC.shape)
xBC = jax.nn.swish(x)  # or: xBC = nnx.swish(x)
print(xBC.shape)

The output shape is not what I expected.

What you expected to happen:

Output:
(1, 128, 288)
(1, 128, 128)
Expected:
(1, 128, 288)
(1, 128, 288)
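A minimal sketch for checking the shape behavior, assuming only that swish is applied elementwise (swish(x) = x * sigmoid(x)). The snippet above prints `xBC.shape` but passes `x` to swish, so the result takes the shape of `x`. The shape used for `x` below is hypothetical, picked to match the reported output.

```python
import jax
import jax.numpy as jnp

# Hypothetical input with the shape reported for xBC in the issue.
xBC = jnp.ones((1, 128, 288))

# swish is elementwise, so the output shape equals the input shape.
out = jax.nn.swish(xBC)
print(xBC.shape, out.shape)  # (1, 128, 288) (1, 128, 288)

# Hypothetical stand-in for the issue's `x`: a different array gives
# an output with that array's shape, not xBC's shape.
x = jnp.ones((1, 128, 128))
out_x = jax.nn.swish(x)
print(out_x.shape)  # (1, 128, 128)
```

If the printed shapes differ, compare the shape of the array actually passed to swish against the one that was printed before the call.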

Hey, currently

assert jax.nn.swish is nnx.swish

so I'm not sure what the issue could be here.

I know they are the same. What I mean is that when I use these functions (which build on each other, or are just different names for the same thing, like silu and swish), they change my input shape.
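A small sketch comparing the two names on the same input, assuming only the JAX functions `jax.nn.swish`, `jax.nn.silu`, and `jax.nn.sigmoid`. It checks that both give the same values as x * sigmoid(x) and that neither changes the shape. The input array is arbitrary.

```python
import jax
import jax.numpy as jnp

# Arbitrary test input of shape (1, 7).
x = jnp.linspace(-3.0, 3.0, 7).reshape(1, 7)

a = jax.nn.swish(x)
b = jax.nn.silu(x)
ref = x * jax.nn.sigmoid(x)  # swish/silu definition, computed by hand

print(a.shape, b.shape)  # (1, 7) (1, 7)
print(jnp.allclose(a, b), jnp.allclose(a, ref, atol=1e-6))
```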