opendilab/LightZero

JAX support

carlosgmartin opened this issue · 3 comments

Would you consider adding support for JAX?

MCTX already provides a JAX implementation of Gumbel MuZero, so we don't plan to add other JAX implementations in the near term. However, if someone wants to participate in our project and contribute more JAX implementations, we are willing to provide the corresponding support.

@PaParaZz1 Looks like mctx doesn't plan to support Sampled MuZero: google-deepmind/mctx#87.

One of the repo owners also says:

The existing Stochastic MuZero implementation is not efficient inside mctx. An alternative library can be created for Stochastic MuZero.

Perhaps there's an opportunity for LightZero here.

As the mctx contributor said, Sampled MuZero makes MCTS applicable to complex action spaces, but it is indeed not as good as the best algorithms on these action spaces. Our current research focuses on designing a new action representation learning method to address this problem. If our paper is accepted, we will release the corresponding PyTorch training code.
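
For readers unfamiliar with why Sampled MuZero helps on complex action spaces, here is a minimal sketch of its core sampling step, not code from LightZero or mctx. Instead of expanding every action at a node, the search draws K candidate actions from the policy and only considers those. The sketch assumes the candidates are drawn from the policy itself, in which case the prior used during search reduces to the empirical distribution of the samples. The names `sample_candidate_actions` and `gaussian_policy` are hypothetical, and the 2-D Gaussian stands in for a learned policy network.

```python
import random
from collections import Counter


def gaussian_policy(rng):
    """Hypothetical stand-in for a learned policy: draws one 2-D continuous action."""
    return (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))


def sample_candidate_actions(policy_sampler, num_samples, rng):
    """Draw num_samples actions and return {action: empirical prior}.

    The tree search then expands only these candidates rather than the
    full (possibly continuous or combinatorial) action space.
    """
    draws = [policy_sampler(rng) for _ in range(num_samples)]
    counts = Counter(draws)
    # Each distinct action gets prior = (times sampled) / (total samples).
    return {action: count / num_samples for action, count in counts.items()}


rng = random.Random(0)
candidates = sample_candidate_actions(gaussian_policy, num_samples=8, rng=rng)
for action, prior in candidates.items():
    print(action, prior)
```

With a continuous policy, duplicate draws are very unlikely, so each candidate usually ends up with prior 1/K. With a discrete policy, repeated draws accumulate into larger priors.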