In this article we will take an in-depth look at implementing PPO using JAX. You can find the code repository here.
Introduction
Proximal Policy Optimization (PPO) is a policy gradient algorithm that incorporates concepts from trust regions, function approximation, and model-free reinforcement learning. In practice, PPO commonly uses artificial neural networks as function approximators, making it a key method in deep reinforcement learning. PPO has achieved notable success across various Gymnasium tasks and has even been used in the fine-tuning of models like ChatGPT through reinforcement learning from human feedback. The core idea of PPO is a clipped objective function, which helps maintain a trust region during optimization and stabilizes training.
| Characteristic | Description |
|---|---|
| Algorithm Type | Policy Gradient |
| Trust Region | Clipped surrogate objective |
| Function Approximator | Neural Networks (typically MLPs or CNNs) |
| On/Off Policy | On-policy |
| Online/Offline | Online |
Before we dive into implementing PPO, let’s cover some background topics to build a better understanding.
Background
Policy Gradient Methods
In policy gradient methods we are concerned with maximizing a (typically parameterized) scalar performance measure \(\mathcal{J}(\mathbf{\theta})\). A common estimator for the performance gradient is: \[
\nabla_{\mathbf{\theta}} \mathcal{J}(\mathbf{\theta}) \approx \hat{\mathbb{E}}_t \left[ \nabla_{\mathbf{\theta}} \log \pi_{\mathbf{\theta}}(a_t|s_t) \, \hat{G}_t \right]
\]
where \(\hat{G}_t\) can be (but does not have to be) the generalized advantage estimate (GAE): \[
\hat{G}_t = \sum_{i=0}^{T-t-1} (\gamma \lambda)^i \, \delta_{t+i}
\]
where \(\delta_t\) is the TD error: \[
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
\]
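To make the recursion concrete, here is a minimal sketch of computing \(\hat{G}_t\) backwards over a single environment's rollout. The function and argument names are illustrative, not the repository's:

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # rewards, values, dones: arrays of shape (S,) for a single environment
    # last_value: bootstrap value V(s_S) for the state after the final step
    S = len(rewards)
    advantages = np.zeros(S)
    gae = 0.0
    for t in reversed(range(S)):
        next_value = last_value if t == S - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]  # zero out the bootstrap at episode ends
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        # G_t = delta_t + (gamma * lambda) * G_{t+1}
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    return advantages
```

Working backwards turns the nested sum over \((\gamma\lambda)^i \delta_{t+i}\) into a single O(S) pass.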
Trust Region Policy Optimization
We can express TRPO as a constrained optimization problem: \[
\max_{\mathbf{\theta}} \; \hat{\mathbb{E}}_t \left[ \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} \, \hat{G}_t \right] \quad \text{subject to} \quad \hat{\mathbb{E}}_t \left[ D_{\mathrm{KL}}\!\left( \pi_{\theta_{\text{old}}}(\cdot|s_t) \,\|\, \pi_\theta(\cdot|s_t) \right) \right] \le \epsilon
\]
This is essentially saying, “Optimize performance, but do it cautiously.” In other words, the constraint prevents the policy from diverging too much from the old policy and limits changes to a so-called trust region specified by the parameter \(\epsilon\).
PPO
The main performance objective of PPO clips the objective \(\frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)} \hat{G}_t\) within the trust-region boundaries represented by the interval \([1-\epsilon, 1+\epsilon]\): \[
\mathcal{J}^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t \left[ \min \left( \rho_t(\theta) \hat{G}_t, \; \mathrm{clip}\left(\rho_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{G}_t \right) \right]
\]
Where \(\rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\)
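The clipped objective maps to JAX almost one-to-one. Here is a minimal sketch that operates on log-probabilities, which is how the ratio is usually computed in practice for numerical stability (the function name and signature are illustrative):

```python
import jax.numpy as jnp

def clipped_surrogate(logprob_new, logprob_old, advantages, epsilon=0.2):
    # rho_t(theta) = pi_theta / pi_theta_old, computed in log space
    rho = jnp.exp(logprob_new - logprob_old)
    unclipped = rho * advantages
    clipped = jnp.clip(rho, 1 - epsilon, 1 + epsilon) * advantages
    # element-wise minimum, then average over the batch
    return jnp.mean(jnp.minimum(unclipped, clipped))
```

Taking the minimum means the clipped bound only kicks in when it makes the objective more pessimistic, which is what keeps updates inside the trust region.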
To visualize the effect of this clipping let’s consider a toy problem with a single parameter \(\theta \in \mathbb{R}\) which creates the parameterized policy
Now imagine at a given time step we receive \(\hat{G}_t = 1\) as our advantage estimate; the loss function (unclipped objective) then has the following form
Now let’s apply the clipping to get the clipped surrogate objective used in PPO (\(\epsilon=0.2\)):
```python
clip = 0.2
L_clipped = lambda action, G: G * jnp.minimum(
    rho_theta(action),
    jnp.clip(rho_theta(action), 1 - clip, 1 + clip),
)

fig, axes = plt.subplots(2, 2, figsize=(24, 24))
for ax_row in axes:
    for ax in ax_row:
        ax.tick_params(axis='both', which='major', labelsize=22)
        ax.title.set_size(28)
        ax.xaxis.label.set_size(26)
        ax.yaxis.label.set_size(26)

axes[0][0].plot(jax.vmap(rho_theta)(a),
                jax.vmap(L_clipped, in_axes=(0, None))(a, 1.0),
                label=r'$J^{CLIP}(\theta), G > 0$')
axes[0][0].set_title(r'$J^{CLIP}(\theta), G > 0$')
axes[0][0].set_xlabel(r'$\rho(\theta)$')
axes[0][0].set_ylabel(r'$J^{CLIP}(\theta)$')
axes[0][0].grid(True)

axes[0][1].plot(jax.vmap(rho_theta)(a),
                jax.vmap(L_clipped, in_axes=(0, None))(a, -1.0),
                label=r'$J^{CLIP}(\theta), G < 0$')
axes[0][1].set_title(r'$J^{CLIP}(\theta), G < 0$')
axes[0][1].set_xlabel(r'$\rho(\theta)$')
axes[0][1].set_ylabel(r'$J^{CLIP}(\theta)$')
axes[0][1].grid(True)

grad_clipped_loss = jax.grad(L_clipped)

axes[1][0].plot(jax.vmap(rho_theta)(a),
                jax.vmap(grad_clipped_loss, in_axes=(0, None))(a, 1.0),
                label=r'$\nabla J^{CLIP}(\theta), G > 0$')
axes[1][0].set_title(r'$\nabla J^{CLIP}(\theta), G > 0$')
axes[1][0].set_xlabel(r'$\rho(\theta)$')
axes[1][0].set_ylabel(r'$\nabla J^{CLIP}(\theta)$')
axes[1][0].grid(True)

axes[1][1].plot(jax.vmap(rho_theta)(a),
                jax.vmap(grad_clipped_loss, in_axes=(0, None))(a, -1.0),
                label=r'$\nabla J^{CLIP}(\theta), G < 0$')
axes[1][1].set_title(r'$\nabla J^{CLIP}(\theta), G < 0$')
axes[1][1].set_xlabel(r'$\rho(\theta)$')
axes[1][1].set_ylabel(r'$\nabla J^{CLIP}(\theta)$')
axes[1][1].grid(True)

plt.suptitle("Clipped Objective")
plt.show()
```
Implementation Details in Discrete Action Space
The implementation details listed here are based on commit id 708816a, so make sure to check that commit out with git if the code and article are out of sync while you're following along. I’m going to continue to improve and develop PPO-JAX in the future.
Pseudocode
\begin{algorithm} \caption{Proximal Policy Optimization} \begin{algorithmic} \State Initialize Actor and Critic Networks \State Initialize Actor Optimizer with $\alpha_{actor}$ \State Initialize Critic Optimizer with $\alpha_{critic}$ \State T $\leftarrow$ Total time-steps \State N $\leftarrow$ Number of parallel envs \State M $\leftarrow$ Number of minibatches \State S $\leftarrow$ Number of rollout steps \State Epochs $\leftarrow$ Number of training epochs \State Batch Size $\leftarrow$ N $\times$ S \State Minibatch Size $\leftarrow$ $\frac{Batch\,Size}{M}$ \State Iterations $\leftarrow$ $\frac{T}{Batch\,Size}$ // How many batches can you collect over T time steps? \State Initialize Rollout Buffer \For{i in Iterations} \For{t in S} \State Rollout Buffer $\leftarrow$ Collect Rollout Statistics \EndFor \For{t in reverse S} \State Compute $G_t$ using Rollout Buffer // General Advantage Estimate \EndFor \For{epoch in Epochs} \For{minibatch in Range(0, Batch Size, Minibatch Size)} \State Sample Random MiniBatch from Batch \State $\mathcal{L}_{Critic}(\mathbf{\psi})$ $\leftarrow$ $\frac{1}{\text{Minibatch Size}} \sum_{(s, G) \in \text{MiniBatch}} (V(s, \mathbf{\psi}) - G)^2$ \State $\mathcal{J}_{Actor}(\mathbf{\theta})$ $\leftarrow$ $\frac{1}{\text{Minibatch Size}} \sum_{(s, a, G) \in \text{MiniBatch}} \min \left( \rho(\mathbf{\theta}) G, \mathrm{clip}(\rho(\mathbf{\theta}), 1-\epsilon, 1+\epsilon) G \right) + c_{entropy} * \mathcal{H}(\pi_\theta(\cdot|s))$ \State $\mathbf{\psi} \gets \mathbf{\psi} - \alpha_{critic} \nabla_{\mathbf{\psi}} \mathcal{L}_{Critic}(\mathbf{\psi})$ \State $\mathbf{\theta} \gets \mathbf{\theta} + \alpha_{actor} \nabla_{\mathbf{\theta}} \mathcal{J}_{Actor}(\mathbf{\theta})$ \EndFor \EndFor \EndFor \end{algorithmic} \end{algorithm}
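The batch bookkeeping at the top of the pseudocode is simple arithmetic. With some illustrative settings (not the repository's defaults):

```python
T = 500_000  # total time-steps
N = 4        # parallel envs
S = 128      # rollout steps
M = 4        # minibatches

batch_size = N * S                # transitions collected per iteration
minibatch_size = batch_size // M  # transitions per gradient step
iterations = T // batch_size      # how many batches fit in T time-steps
```

With these numbers each iteration collects 512 transitions, each gradient step sees 128 of them, and 976 iterations cover the time-step budget.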
Rollout Buffer
This is where we collect the episode values and statistics.
```python
obs = np.zeros((args.num_steps, args.num_envs) +
               envs.single_observation_space.shape)  # (S, N, D): D is the observation vector size
actions = np.zeros((args.num_steps, args.num_envs) +
                   envs.single_action_space.shape)   # (S, N, A): A is the number of actions
logprobs = np.zeros((args.num_steps, args.num_envs))  # (S, N)
rewards = np.zeros((args.num_steps, args.num_envs))   # (S, N)
dones = np.zeros((args.num_steps, args.num_envs))     # (S, N)
values = np.zeros((args.num_steps, args.num_envs))    # (S, N)
```
At first it may seem more intuitive to have the number of environments as the first dimension. But placing the time step first makes it easier to apply value functions and other rollout statistics across all parallel environments at a given time step. In other words, it’s easier to do `obs[step] = next_obs` than `obs[:, step] = next_obs`. To drive this point home, consider this example:

- The assignment in the first (time-major) buffer is simpler.
- The values for a given step are stored closer together in memory, which can improve access locality.
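A quick sketch of the two layouts with toy sizes (names here are illustrative):

```python
import numpy as np

S, N, D = 4, 8, 3  # steps, envs, obs dim (toy sizes)
obs_time_major = np.zeros((S, N, D))  # layout used above: (S, N, D)
obs_env_major = np.zeros((N, S, D))   # alternative: (N, S, D)

next_obs = np.ones((N, D))  # one observation per parallel env at this step
step = 2

obs_time_major[step] = next_obs   # simple row assignment
obs_env_major[:, step] = next_obs # needs an extra slice
```

Both store the same data; the time-major layout just makes the per-step write the natural one.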
Training Loop
```python
for iteration in range(1, args.num_iterations + 1):  # number of iterations is calculated at run-time (see pseudocode)
    for step in range(0, args.num_steps):  # rollout steps specified by the user and passed using `tyro` args
        # Collect rollout
        ...
    for t in reversed(range(args.num_steps)):  # go through the steps backwards, dynamically calculating advantages
        # Compute GAE (generalized advantage estimation)
        ...
    # prepare rollout batches
    for epoch in range(args.update_epochs):  # train for the specified number of epochs
        # ...
        for start in range(0, args.batch_size, args.minibatch_size):  # sample minibatches
            # Optimize losses over the minibatch
            ...
```
Actor and Critic Networks
To have a value function we need to create a critic neural network. A simple multi-layered perceptron works well for our purposes:
```python
class Critic(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64, kernel_init=layer_init(np.sqrt(2)))(x)
        x = nn.tanh(x)
        x = nn.Dense(64, kernel_init=layer_init(np.sqrt(2)))(x)
        x = nn.tanh(x)
        x = nn.Dense(1, kernel_init=layer_init(1.0))(x)
        return x.squeeze(-1)


class Actor(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64, kernel_init=layer_init(np.sqrt(2)))(x)
        x = nn.tanh(x)
        x = nn.Dense(64, kernel_init=layer_init(np.sqrt(2)))(x)
        x = nn.tanh(x)
        x = nn.Dense(self.action_dim, kernel_init=layer_init(0.01))(x)
        return x
```
To initialize the model we can do something like `params = model.init(rng, jnp.ones(obs_shape))`, where `obs_shape` is the shape of a single observation sample. You may be wondering: we are dealing with batches of data and a rollout of S steps, so how do we do a forward pass (or `apply`, in JAX lingo) over the larger dimensions? This is where JAX conveniently shines through and gives us `vmap`:
```python
def value_fn(params, obs):
    # vmap dynamically scales the single-sample function to the batch dimensions
    value = jax.vmap(lambda x: critic_state.apply_fn(params, x))(obs)
    return value

def policy_fn(params, obs):
    return jax.vmap(lambda x: actor_state.apply_fn(params, x))(obs)
```
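To see the shape handling in isolation, here is a toy stand-in for the critic's `apply_fn` (the real one comes from the Flax train state; the dot product below is purely illustrative):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for critic_state.apply_fn: maps a single obs of shape (D,) to a scalar
def apply_fn(params, x):
    return jnp.dot(params, x)

params = jnp.ones(3)
obs = jnp.ones((8, 3))  # batch of N=8 observations, D=3

# vmap lifts the single-sample function over the leading batch dimension
values = jax.vmap(lambda x: apply_fn(params, x))(obs)
```

The batched call returns one value per observation without the network code ever knowing about the batch dimension.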
Performing Rollout
During rollout we compute statistics such as state/observation values and sample actions from the policy:
\[
\hat{V}(S, \mathbf{\psi})
\]
value = value_fn(critic_state.params, next_obs)
Sampling an action from a discrete action space using the logits produced by the actor network takes a bit more work:
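A minimal sketch of how that sampling can look (the toy logits below stand in for the actor network's output; in the real code they come from `policy_fn`):

```python
import jax
import jax.numpy as jnp

# Toy logits standing in for the actor network's output (batch of 2 envs, 3 actions)
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.1, 3.0]])

key = jax.random.PRNGKey(0)
# split so the sampling key is never reused across steps
key, subkey = jax.random.split(key)

# sample one action per row of logits
actions = jax.random.categorical(subkey, logits, axis=-1)
# log-probability of the sampled action, needed later for the ratio rho
logprobs = jax.nn.log_softmax(logits)[jnp.arange(logits.shape[0]), actions]
```

Storing `logprobs` at rollout time is what lets us compute \(\rho_t(\theta)\) later without re-running the old policy.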
The nuances of using `jax.random.split` and `jax.random.categorical` really reflect the gap between math symbols and theory on one side and working code on the other. Make sure to check out how a rollout is collected in the code.
We want to minimize the state-value errors and maximize the policy objective. In the code I call both of these measures a “loss”: one is `actor_loss_fn` and the other `critic_loss_fn`. Since the actor objective is maximized, it would be more appropriate to rename `actor_loss_fn`; something for future iterations.
Combining the rollout, loss computation, and optimization steps, we get our very own Proximal Policy Optimization algorithm in JAX (Yay! 🎉🤖), ready to be used on some cool control tasks (check out the Results section next). Here are some final thoughts:
JIT can speed things up a lot when done right, but it can also slow things down when misused. A good rule of thumb is to JIT functions whose parameter shapes are fixed and won’t change during runtime. This ensures that we don’t incur overhead from recompiling the same function over and over.
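For example, a small fixed-shape helper is a safe candidate for `jax.jit` (this function is illustrative, not from the repository):

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(adv):
    # shapes are fixed per rollout, so this traces and compiles once
    return (adv - adv.mean()) / (adv.std() + 1e-8)

batch = jnp.arange(8.0)
out = normalize(batch)        # first call traces and compiles
out2 = normalize(batch + 1)   # same shape and dtype: reuses the compiled version
```

Calling it with a different shape would trigger a fresh trace, which is exactly the recompilation overhead to avoid.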
Optax has a useful feature for reducing the learning rate as training progresses. Use it to “anneal” your learning rate.
This implementation of PPO fails on the MountainCar task. What could be wrong? Food for thought.
In future iterations it would be better to refactor the code so that the logits are not recalculated multiple times before sampling. A single function handling both would be cleaner.
On that note, this implementation “works” but can still be improved, and that’s what’s in store for this project, so stay tuned :)