bp.dnn.ToFlaxRNNCell is not working
Dr-Chen-Xiaoyu opened this issue · 4 comments
Dr-Chen-Xiaoyu commented
Hi, chaoming,
I am trying to use bp.dnn.ToFlaxRNNCell(), but it raises an error. I suspect this is caused by API changes in newer versions of Flax, or maybe I am misusing the function?
Best,
Xiaoyu Chen
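```python
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print('jax version:', jax.__version__)
print('flax version:', flax.__version__)
```

```
bp version: 2.6.0
jax version: 0.4.26
```

(The originally posted "flax version: 0.4.26" line came from printing `jax.__version__` a second time, so the installed Flax version was not actually reported.)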
```python
cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1))

class myRNN(nn.Module):
    @nn.compact
    def __call__(self, x):  # x: (batch, time, features)
        x = nn.RNN(cell)(x)  # use nn.RNN to unroll the recurrent cell over time
        return x

model = myRNN()
model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1]))  # (batch, time, features)
```
```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/data/xyc/codes/Tests/BrainPy/hessian.ipynb Cell 27 line 1
      7         return x
      9 model = myRNN()
---> 10 model.init(jax.random.PRNGKey(0), jnp.ones([1,10,1])) # batch,time,feature

    [... skipping hidden 9 frame]

/data/xyc/codes/Tests/BrainPy/hessian.ipynb Cell 27 line 6
      4     @nn.compact
      5     def __call__(self, x): # x:(batch, time, features)
----> 6         x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell
      7         return x

    [... skipping hidden 2 frame]

File ~/anaconda/envs/env_bp_cpu/lib/python3.11/site-packages/flax/linen/recurrent.py:1066, in RNN.__call__(self, inputs, initial_carry, init_key, seq_lengths, return_carry, time_major, reverse, keep_order)
   1061 keep_order = self.keep_order
   1063 # Infer the number of batch dimensions from the input shape.
   1064 # Cells like ConvLSTM have additional spatial dimensions.
   1065 time_axis = (
-> 1066   0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1)
   1067 )
   1069 # make time_axis positive
   1070 if time_axis < 0:

    [... skipping hidden 1 frame]

File ~/anaconda/envs/env_bp_cpu/lib/python3.11/site-packages/flax/linen/recurrent.py:84, in RNNCellBase.num_feature_axes(self)
     81 @property
     82 def num_feature_axes(self) -> int:
     83   """Returns the number of feature axes of the RNN cell."""
---> 84   raise NotImplementedError

NotImplementedError:
```
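The traceback suggests that the object returned by bp.dnn.ToFlaxRNNCell does not override `num_feature_axes`, which this Flax version reads inside `nn.RNN.__call__` to locate the time axis, so the lookup falls through to the base-class property that raises. The stdlib-only sketch below mimics that mechanism. `CellBase`, `OldStyleCell`, `NewStyleCell` and `infer_time_axis` are hypothetical stand-ins, not Flax or BrainPy classes, and the wrap-around for negative axes is an assumption because that part of the frame is hidden in the traceback.

```python
class CellBase:
    """Hypothetical, simplified stand-in for flax.linen.RNNCellBase."""

    @property
    def num_feature_axes(self) -> int:
        # Base class leaves this abstract, as in the traceback.
        raise NotImplementedError


class OldStyleCell(CellBase):
    """A cell written before Flax required num_feature_axes (no override)."""


class NewStyleCell(CellBase):
    @property
    def num_feature_axes(self) -> int:
        return 1  # one trailing feature axis, e.g. inputs of shape (batch, time, features)


def infer_time_axis(ndim: int, cell: CellBase, time_major: bool = False) -> int:
    # Same expression as recurrent.py line 1066 in the traceback.
    time_axis = 0 if time_major else ndim - (cell.num_feature_axes + 1)
    if time_axis < 0:  # assumed: negative axes wrap like Python indices
        time_axis += ndim
    return time_axis
```

For a `(batch, time, features)` input, `infer_time_axis(3, NewStyleCell())` returns `1`, while `infer_time_axis(3, OldStyleCell())` raises `NotImplementedError`, matching the failure above.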
chaoming0625 commented
Yes, this is a version issue: Flax's RNN API has changed in newer releases.
chaoming0625 commented
Sorry, I am busy with other things. I may be able to provide a fix this weekend.
chaoming0625 commented
Moreover, I will also give you a solution for parallelization this weekend. Sorry for the late response.
chaoming0625 commented
See #665