Incompatibility on `Brax Problem` of EvoX with the Latest Version of Flax
Hello,
I would like to report an incompatibility issue between the newly supported NNX in Flax and the existing version of EvoX.
The EvoX framework implements a `Problem` for Brax, which requires a `policy` argument that should be the `apply` function of a Flax model defined using `flax.linen`. In this context, `apply` requires the model weights to be passed as one of its arguments. However, with the newer `flax.nnx` API, the function that would be passed as the `policy` argument of the Brax `Problem` no longer takes the model weights as input.
This discrepancy makes NNX models incompatible with the existing `Problem` in EvoX, because the expected function signature no longer matches what Flax provides.
It would be great if this issue could be addressed in a future update to ensure compatibility between these frameworks.
Thank you for your attention.
I noticed this issue as well. From the API's standpoint, we can address it fairly easily, given that the weights in NNX are stored directly on `self`. However, I am uncertain about NNX's compatibility with various transformations, such as `tree_map` and `vmap` (especially when mapping along the weight dimension, which is a must for the EC workflow, though perhaps not essential for the Flax developers). By design, the new NNX is not purely functional, which is a drawback for us, so for now I would advise against using NNX.
I see the point now. Thank you for the clarification.
We can keep this issue open to keep track of the progress.