DQN JAX
It’s officially here: my implementation of DQN in JAX. I’m going to move away from documenting everything here on the blog and keep things contained inside the repository instead. This implementation is part of a bigger project implementing other common RL baselines in JAX, which will include PPO and SAC in the near future.
The philosophy here is to pair JAX with standard Gymnasium and MuJoCo environments rather than build specialized environments like MJX. This makes the code more useful for other applications, although it won’t harness the full potential of JIT and XLA, since the environment steps run outside of JAX. You can find the code here.
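To make that trade-off concrete, here is a minimal sketch of the pattern, not the code from the repository. It assumes a hypothetical `ToyEnv` that only mimics the Gymnasium `reset`/`step` signatures, and a placeholder one-layer Q-network; the names `ToyEnv`, `init_params`, `greedy_action`, and `rollout` are all made up for illustration. The point is that only the numerical part gets jitted, while the environment loop stays in plain Python.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyEnv:
    """Hypothetical stand-in that mimics the Gymnasium reset/step signatures."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return np.zeros(4, dtype=np.float32), {}

    def step(self, action):
        self.t += 1
        obs = np.full(4, self.t, dtype=np.float32)
        reward = 1.0 if action == 1 else 0.0
        terminated = self.t >= self.horizon
        # Same 5-tuple as Gymnasium: obs, reward, terminated, truncated, info
        return obs, reward, terminated, False, {}


def init_params(key, obs_dim=4, n_actions=2):
    # A single linear layer as a placeholder Q-network (an assumption).
    w = jax.random.normal(key, (obs_dim, n_actions)) * 0.1
    return {"w": w, "b": jnp.zeros(n_actions)}


@jax.jit
def greedy_action(params, obs):
    # Only this numerical part is compiled by XLA.
    q_values = obs @ params["w"] + params["b"]
    return jnp.argmax(q_values)


def rollout(env, params):
    # The environment loop runs in plain Python, so it is not jitted.
    obs, _ = env.reset(seed=0)
    total, steps, done = 0.0, 0, False
    while not done:
        action = int(greedy_action(params, jnp.asarray(obs)))
        obs, reward, terminated, truncated, _ = env.step(action)
        total += reward
        steps += 1
        done = terminated or truncated
    return total, steps


params = init_params(jax.random.PRNGKey(0))
total, steps = rollout(ToyEnv(horizon=5), params)
```

Because `ToyEnv` follows the same `reset`/`step` shape as Gymnasium, swapping in a real environment with a 4-dimensional observation and 2 discrete actions should only require changing the line that constructs the environment. Every step still crosses the Python/JAX boundary, which is the cost mentioned above.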
Feel free to browse through the wiki for more info on the algorithms. I’ll document more stuff there.
If you find the code useful, don’t forget to leave a ⭐!