A simple and efficient library for training humanoid locomotion in MJX and MuJoCo.
- Clone this repository:
git clone https://github.com/kscalelabs/ksim.git
cd ksim
- We recommend installing the dependencies for this project in a virtual environment. To create and activate a new conda environment, run:
conda create --name ksim python=3.11
conda activate ksim
To install the dependencies, run the following command:
make install-dev
Finally, add your Weights & Biases API key via conda:
conda env config vars set WANDB_API_KEY=<your api key>
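Variables set with conda env config vars only take effect after the environment is re-activated, so it can be worth confirming that the key is visible to Python before starting a long run. The following is a minimal sketch that assumes only that the key is exposed through the WANDB_API_KEY environment variable; the helper name wandb_key_is_set is hypothetical and not part of ksim.

```python
import os


def wandb_key_is_set(env=None):
    """Return True if WANDB_API_KEY is present and non-empty.

    `env` defaults to the process environment; passing a dict makes the
    check easy to exercise without touching real credentials.
    """
    env = os.environ if env is None else env
    return bool(env.get("WANDB_API_KEY", "").strip())


if __name__ == "__main__":
    if wandb_key_is_set():
        print("WANDB_API_KEY is set")
    else:
        print("WANDB_API_KEY is missing; try re-activating the conda environment")
```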
MJX Gym is a library for training and evaluating reinforcement learning agents in MJX environments. It is built on top of the Brax library and provides a simple interface for running experiments with Stompy and other humanoid form factors. Currently, we support walking, and we plan to add more tasks and simulator environments in the future.
For quick experimentation, you may specify all relevant training configurations via YAML files, and simply run train.py with the desired configuration. Our configurations allow for rapid reward function prototyping, environment specification, and hyperparameter tuning. Additionally, all training results are tracked and logged using Weights & Biases.
We recommend starting with the default humanoid environment to get a feel for the simulator. To train the default humanoid environment, first navigate to the /mjx_gym directory:
cd ksim/mjx_gym
Then, run the following command:
python train.py --config experiments/default_humanoid_walk.yaml
Example training curves are shown below:
We provide an easy way to test and render the trained model using the play.py script. This script loads the trained model and runs it in the specified environment.
To test and render the trained model, run the following command:
python play.py --config experiments/default_humanoid_walk.yaml
Here is an example of the trained model walking in the default humanoid environment after roughly 15 minutes of training on the K-Scale cluster.
You may transfer the model (trained via MJX) to MuJoCo for CPU-based simulation via arguments to the play script. For example, to run the model in the default humanoid environment, use the following command:
python play.py --config experiments/default_humanoid_walk.yaml --use_mujoco
It is important that the configuration file used for testing is the same as the one used for training. This ensures that the model is loaded correctly and avoids any catastrophic failures during testing. For simplicity, we have included default training weights in the /mjx_gym/weights directory, which are automatically used when running the play script with the command above. However, you may specify a different weight file using the --params_path argument.
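One way to guard against a train/test configuration mismatch is to record a fingerprint of the configuration file next to the weights and compare it before playback. The sketch below is not something ksim does itself, as far as this README states; the sidecar file name and the helper names are hypothetical, and it assumes the configuration is a single file on disk.

```python
import hashlib
import json
from pathlib import Path


def config_fingerprint(config_path):
    """Return the SHA-256 hex digest of the configuration file's bytes."""
    return hashlib.sha256(Path(config_path).read_bytes()).hexdigest()


def _sidecar_path(weights_path):
    # Hypothetical layout: "<weights file>.config.json" next to the weights.
    return Path(str(weights_path) + ".config.json")


def save_fingerprint(config_path, weights_path):
    """Record the training config's fingerprint alongside the weights."""
    sidecar = _sidecar_path(weights_path)
    sidecar.write_text(json.dumps({"config_sha256": config_fingerprint(config_path)}))
    return sidecar


def config_matches(config_path, weights_path):
    """Return True only if a recorded fingerprint exists and matches."""
    sidecar = _sidecar_path(weights_path)
    if not sidecar.exists():
        return False
    recorded = json.loads(sidecar.read_text())["config_sha256"]
    return recorded == config_fingerprint(config_path)
```

Calling save_fingerprint at the end of training and config_matches before playback would flag an edited or swapped configuration. Because the hash covers raw bytes, a whitespace-only change also counts as a mismatch.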
Here is an example of the trained model (in MJX) walking in the default humanoid environment in MuJoCo.
The model performs well in MJX, but its performance may vary when transferred to MuJoCo because of slight differences between the two simulators and their underlying physics engines. Evaluating the model in both simulators helps determine whether it will generalize across environments.
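A simple way to quantify the sim-to-sim gap is to compare mean episode returns collected in each simulator. The sketch below assumes you have already gathered per-episode returns yourself; the function name return_gap is hypothetical, and the numbers in the usage lines are placeholders, not measured results.

```python
from statistics import mean


def return_gap(mjx_returns, mujoco_returns):
    """Summarize how much the mean episode return changes between simulators.

    `relative_drop` is (MJX mean - MuJoCo mean) / |MJX mean|, so a positive
    value means the policy does worse in MuJoCo than in MJX.
    """
    mjx_mean = mean(mjx_returns)
    mujoco_mean = mean(mujoco_returns)
    if mjx_mean == 0:
        raise ValueError("MJX mean return is zero; relative drop is undefined")
    return {
        "mjx_mean": mjx_mean,
        "mujoco_mean": mujoco_mean,
        "relative_drop": (mjx_mean - mujoco_mean) / abs(mjx_mean),
    }


if __name__ == "__main__":
    # Placeholder values for illustration only.
    summary = return_gap([100.0, 100.0], [50.0, 50.0])
    print(summary["relative_drop"])
```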
Getting Stompy to walk in MJX is a challenging task, and we encourage the open source community to contribute to training efforts. The most recent set of weights allows Stompy (without upper body) to maintain balance in a stable environment. An example trajectory is included below:
To run from a SLURM cluster (e.g. the K-Scale Andromeda cluster), cd to /mjx_gym/ and run:
sbatch scripts/train.slurm
Training logs, errors, and GPU utilization are written to the /logs/ directory.
You might see the following error when running train.py:
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
One way to fix this issue is to uninstall jax and jaxlib, then reinstall them for your specific CUDA version using the following commands:
pip uninstall jax jaxlib
pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note that the CUDA version should match the one installed on your machine.
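After reinstalling, you can check which backend JAX actually selected. This is a minimal sketch using jax.default_backend() and jax.devices(); the wrapper name describe_jax_backend is hypothetical. If the output still reports cpu on a GPU machine, the CUDA-enabled jaxlib was most likely not picked up.

```python
def describe_jax_backend():
    """Return a one-line summary of the backend and devices JAX will use."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices={devices}"


if __name__ == "__main__":
    print(describe_jax_backend())
```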
A similar issue might also occur:
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
This error is usually caused by an incorrect version of cuDNN installed in your environment. To fix this issue, you can try installing cuDNN using the following command:
conda install cudnn
Another common issue occurs when rendering on a headless server. For now, we recommend rendering locally or using a remote desktop connection to view the rendering.
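If you want to experiment with headless rendering anyway, MuJoCo's Python bindings read the MUJOCO_GL environment variable to choose an OpenGL backend (glfw, egl, or osmesa), and it must be set before mujoco is imported. The sketch below only sets that variable; whether play.py renders correctly with it has not been verified here, and the helper name select_headless_gl is hypothetical.

```python
import os

VALID_BACKENDS = ("egl", "osmesa", "glfw")


def select_headless_gl(backend="egl", env=None):
    """Set MUJOCO_GL if it is not already set and return the active value.

    An existing user choice is respected rather than overwritten. Call this
    before importing mujoco so the bindings see the setting.
    """
    if backend not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {VALID_BACKENDS}")
    env = os.environ if env is None else env
    return env.setdefault("MUJOCO_GL", backend)


if __name__ == "__main__":
    print("MUJOCO_GL =", select_headless_gl("egl"))
```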
To visualize any MJCF file, you can run the following command:
python3 -m mujoco.viewer --mjcf <path-to-mjcf-file>
The command above loads MuJoCo's GUI, which allows you to simulate the model, manually specify joints, and save keyframes.
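Since MJCF is plain XML, you can also inspect a model's joints without opening the GUI. The following sketch uses only the Python standard library; the function name list_joints and the inline demo model are hypothetical. It does not resolve default classes or include files, and it skips freejoint elements, so treat the output as a quick overview rather than what MuJoCo compiles.

```python
import xml.etree.ElementTree as ET

DEMO_MJCF = """
<mujoco model="demo">
  <worldbody>
    <body name="base">
      <joint name="rail" type="slide" axis="1 0 0"/>
      <geom type="box" size="0.1 0.1 0.1"/>
      <body name="leg">
        <joint name="hip" axis="0 1 0"/>
        <geom type="capsule" size="0.05 0.2"/>
      </body>
    </body>
  </worldbody>
</mujoco>
"""


def list_joints(mjcf_xml):
    """Return (name, type) for each <joint> under <worldbody>, in document order.

    Only <worldbody> is searched so that <joint> entries in <default>,
    <tendon>, or <equality> sections are not mistaken for body joints.
    A missing type is reported as "hinge", which is MuJoCo's default.
    """
    root = ET.fromstring(mjcf_xml)
    worldbody = root.find("worldbody")
    if worldbody is None:
        return []
    return [
        (joint.get("name", "<unnamed>"), joint.get("type", "hinge"))
        for joint in worldbody.iter("joint")
    ]


if __name__ == "__main__":
    for name, joint_type in list_joints(DEMO_MJCF):
        print(name, joint_type)
```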
- Bootstrap
- Get Stompy walking with PPO through bootstrapping
- Implement imitation learning techniques for end-to-end walking and recovering tasks
- Add CPU-based MuJoCo training platform for improved sim-to-sim support
- Add sim-to-real transfer learning techniques during training