nnx.swish, jax.nn.swish, ... change the input shape
leson207 opened this issue · 2 comments
leson207 commented
System information
- Platform: Colab
- Flax: latest version from GitHub (!pip install git+https://github.com/google/flax.git)
Problem you have encountered:
print(xBC.shape)
xBC = jax.nn.swish(x)  # or: xBC = nnx.swish(x)
print(xBC.shape)
The output shape is not what I expected.
What you expected to happen:
Output:
(1, 128, 288)
(1, 128, 128)
Expected:
(1, 128, 288)
(1, 128, 288)
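A minimal sketch for checking this in isolation, assuming random arrays stand in for the reporter's data (the names x and xBC come from the snippet above; their contents, and the (1, 128, 128) shape assumed for x, are hypothetical). Swish is applied elementwise, so the result always has the shape of whatever array is passed in. Note that the snippet prints xBC.shape but calls the function on x, so it is worth confirming which array the call actually receives.

```python
import jax

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical stand-ins: xBC matches the first printed shape,
# x is assumed to be a separate array of shape (1, 128, 128).
xBC = jax.random.normal(key_a, (1, 128, 288))
x = jax.random.normal(key_b, (1, 128, 128))

print(xBC.shape)                # (1, 128, 288)
print(jax.nn.swish(xBC).shape)  # (1, 128, 288): the argument's shape is kept
print(jax.nn.swish(x).shape)    # (1, 128, 128): the result follows x, not xBC
```

If the second line already prints (1, 128, 288) in the real code, the activation is preserving the shape and the change comes from whichever array is being passed.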
cgarciae commented
Hey, currently
assert jax.nn.swish is nnx.swish
so I'm not sure what the issue could be here.
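For reference, swish (also called SiLU) is defined elementwise as swish(x) = x * sigmoid(x), so by construction it cannot change the shape of its input. The sketch below uses only JAX (not Flax) and checks jax.nn.swish against that formula on an array with the shape from the report; the random input is an assumption for illustration.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1, 128, 288))

# swish(x) = x * sigmoid(x), computed element by element
manual = x * jax.nn.sigmoid(x)
builtin = jax.nn.swish(x)

print(builtin.shape)                                   # (1, 128, 288)
print(bool(jnp.allclose(manual, builtin, atol=1e-6)))  # True
```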
leson207 commented
I know they are the same. What I mean is that when I use these functions, whether they are built on each other or are just different names for the same function (silu, swish), they change my input shape.
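To compare the two names directly, here is a sketch assuming a standard JAX install and a random input of the reported shape: jax.nn.silu and jax.nn.swish name the same activation, so on the same input they should give matching values and the same shape as the input.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(1), (1, 128, 288))

y_silu = jax.nn.silu(x)
y_swish = jax.nn.swish(x)

# Both names refer to the same elementwise activation,
# so neither can change the input shape.
assert y_silu.shape == y_swish.shape == x.shape
print(bool(jnp.allclose(y_silu, y_swish)))  # True
```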