EMI-Group/evox

Incompatibility on `Brax Problem` of EvoX with the Latest Version of Flax


Hello,

I would like to report an incompatibility issue between the newly supported NNX in Flax and the existing version of EvoX.

The EvoX framework implements a Problem for Brax that takes a `policy` argument, which is expected to be the `apply` function of a Flax model defined with `flax.linen`. That `apply` function requires the model weights to be passed as one of its arguments. With the new `flax.nnx` API, however, the policy passed to the Brax Problem no longer takes the model weights as input.
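A minimal sketch of the signature mismatch, written in plain JAX rather than Flax so it does not depend on a particular Flax version. `linen_style_apply`, `StatefulPolicy`, `split_policy`, and `functional_apply` are hypothetical names, not EvoX or Flax API; they only mimic the `apply(params, obs)` convention of `flax.linen` and the weights-stored-on-the-object convention of `flax.nnx`. Recent Flax versions provide `nnx.split` and `nnx.merge` for a similar separation of weights from an NNX module, but the adapter below is an illustration, not that API.

```python
import jax
import jax.numpy as jnp

# Hypothetical names throughout; not EvoX or Flax API.


def linen_style_apply(params, obs):
    """Functional policy in the shape the Brax Problem expects: weights are an argument."""
    return jnp.tanh(obs @ params["w"] + params["b"])


class StatefulPolicy:
    """NNX-like stand-in: weights live on the object, so calls take only the observation."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, obs):
        return jnp.tanh(obs @ self.w + self.b)


def split_policy(model):
    """Pull the weights out of the object as a plain pytree (in the spirit of nnx.split)."""
    return {"w": model.w, "b": model.b}


def functional_apply(params, obs):
    """Rebuild a fresh stateful model from explicit weights and call it (cf. nnx.merge)."""
    return StatefulPolicy(params["w"], params["b"])(obs)


obs_dim, act_dim = 4, 2
model = StatefulPolicy(
    jax.random.normal(jax.random.PRNGKey(0), (obs_dim, act_dim)) * 0.1,
    jnp.zeros(act_dim),
)
obs = jnp.ones(obs_dim)

# model(obs) has no weights argument; functional_apply restores the (weights, obs) signature.
params = split_policy(model)
action = functional_apply(params, obs)
```

Building a fresh object inside `functional_apply` on every call, instead of mutating `model`, keeps the adapter side-effect free, which is what lets JAX transformations trace it safely.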

This discrepancy makes NNX models incompatible with the existing Brax Problem in EvoX, because the function signature it expects no longer matches what Flax NNX provides.

It would be great if this issue could be addressed in a future update to ensure compatibility between these frameworks.

Thank you for your attention.

I noticed this issue as well. From the API's standpoint, we can address it fairly easily, given that in NNX the weights are stored directly in `self`. However, I am uncertain about NNX's compatibility with various transformations, such as `tree_map` and `vmap`, especially when mapping along the weight dimension, which is a must for the EC workflow (though perhaps not essential for the Flax developers). In a way, NNX is not functional (as stated in its design), which is a drawback for us, so for now I would advise against using NNX.
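To illustrate what "mapping along the weight dimension" means for the EC workflow, here is a minimal pure-JAX sketch, assuming the policy weights are a plain pytree with the population stacked along a leading axis. `policy_apply` and `init_population` are hypothetical names, not EvoX API. Because the weights are ordinary data passed as an argument, `jax.vmap` can evaluate every individual at once and `jax.tree_util` functions can perturb them directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical names; not EvoX API.


def policy_apply(params, obs):
    """Functional (linen-style) policy: weights are an explicit argument."""
    return jnp.tanh(obs @ params["w"] + params["b"])


def init_population(key, pop_size, obs_dim, act_dim):
    """Stack pop_size independent weight sets along a leading population axis."""
    w = jax.random.normal(key, (pop_size, obs_dim, act_dim)) * 0.1
    b = jnp.zeros((pop_size, act_dim))
    return {"w": w, "b": b}


pop_size, obs_dim, act_dim = 8, 4, 2
population = init_population(jax.random.PRNGKey(0), pop_size, obs_dim, act_dim)
obs = jnp.ones(obs_dim)

# Map over the leading axis of every weight leaf; the observation is shared (in_axes=None).
batched_apply = jax.jit(jax.vmap(policy_apply, in_axes=(0, None)))
actions = batched_apply(population, obs)  # shape (pop_size, act_dim)

# Gaussian perturbation of every weight leaf, as an evolutionary operator might apply.
leaves, treedef = jax.tree_util.tree_flatten(population)
keys = jax.random.split(jax.random.PRNGKey(1), len(leaves))
mutated = jax.tree_util.tree_unflatten(
    treedef,
    [x + 0.01 * jax.random.normal(k, x.shape) for x, k in zip(leaves, keys)],
)
```

With a stateful NNX module, the same pattern would first require separating the weights from the module object, which is where the compatibility concern above comes from.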

I see the point now. Thank you for the clarification.

We can keep this issue open to keep track of the progress.