Modern large language models are typically trained in three steps:
- First, you pre-train LLMs on a humongous amount of web-crawled data.
- Then, you supervised fine-tune (SFT) the pre-trained models on a large dataset consisting of “expert” prompt-response pairs.
- Lastly, you “align” the models’ behaviors with human preferences using a technique called reinforcement learning from human feedback (RLHF).
During pre-training, models are effectively turned into general-purpose statistical machines that can decently predict the next token that follows a sequence of tokens. After this initial stage, models are better equipped with the knowledge of how to generate tokens, given the trillions and gazillions of tokens they have learned from, but they are not quite useful yet for real tasks, so we subsequently teach them in a supervised way how to generate better responses by exposing them to labeled input-output pairs. However, this is a rather expensive procedure because such datasets are hard to collect, and SFT alone may not be sufficient. The models might have learned how to produce accurate information, yet their behaviors could still diverge from our expectations of how models should respond. The supervised fine-tuned models could generate subpar responses for creative tasks, since the SFT phase is still all about maximizing the likelihood using autoregressive next-token prediction. Even worse, models might generate responses that are biased, harmful, overly formal, or lengthy, just to list some of the undesirable traits for an LLM response. RLHF is a technique that addresses these shortcomings to better align models’ responses with our expectations. The way we do it is by training a reward model on data that encodes our preferences, and using it as a proxy to update the LLMs’ behaviors, or “policy,” to borrow RL terminology.
RLHF is not a form of sorcery specifically invented for LLM post-training, but it has undoubtedly contributed immensely to a revolution that has rendered LLMs more helpful and safer. RLHF has traditionally been an integral part of alignment research, and the pioneer of this revolution as it pertains to LLM alignment was OpenAI’s InstructGPT, which was trained using RLHF on top of GPT-3. It was shown to generate responses preferred by human labelers and to improve the truthfulness of the responses, while still excelling at traditional NLP tasks as tested on conventional benchmarks (Ouyang et al. (2022), Training language models to follow instructions with human feedback). Anthropic also showed that RLHF improves LLMs’ performance, but notably stressed that the mechanisms of RLHF based on preference modeling (which we’ll discuss later) could function as a double-edged sword, raising concerns about censorship, fraud, and misinformation (Bai et al. (2022), Training a helpful and harmless assistant with reinforcement learning from human feedback).
The north star of this post is to lay out the details of RLHF in a granular way, with a particular focus on the optimization algorithms used to update the policy, among which we’ll discuss PPO, GRPO, DPO, and KTO.
Basic concepts in RL
Agent, state, action, reward, and trajectory
An RL agent interacts with the world it is situated in by receiving the (potentially incomplete and fragmentary) information about the world, which we refer to as states. It takes an action given its knowledge of the world, and receives a reward (the reward may also be deferred until the end, rather than given after each action is taken), where a high reward indicates to the agent that it was a promising action to take, hence increasing the likelihood of taking the same action in the future. Along with the reward, the agent receives a new state.
To articulate this a little more mathematically, we can first define the state space \(\mathcal S\) and the action space \(\mathcal A\) representing the set of all available states and actions the agent can take. At time step \(t\), the agent has access to state \(s_t\); takes action \(a_t\); and receives reward \(r_t\) and a new state \(s_{t+1}\). A trajectory, denoted as \(\tau\), is a sequence of states, actions, and rewards that has happened. Oftentimes, \(T\) is reserved as a symbol that denotes the terminal timestep, so a trajectory from timestep \(t\) to \(T\) can be expressed as \(\tau = (s_t, a_t, r_t, s_{t+1}, a_{t+1}, r_{t+1}, \cdots, s_T)\).
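To make the interaction loop concrete, here is a minimal sketch that records a trajectory \(\tau\) as a flat list of states, actions, and rewards. The 1D corridor environment, the `step` and `rollout` functions, and the `always_right` policy are all made up for illustration; nothing here comes from a real RL library.

```python
import random

def step(state, action):
    """Hypothetical 1D corridor with states 0..3; reaching state 3 ends the episode."""
    next_state = max(0, state + (1 if action == "R" else -1))
    reward = 1.0 if next_state == 3 else 0.0
    done = next_state == 3
    return next_state, reward, done

def rollout(policy, start_state, max_steps, rng):
    """Collect tau = (s_1, a_1, r_1, s_2, ..., s_T) as a flat list."""
    trajectory = [start_state]
    state = start_state
    for _ in range(max_steps):
        action = policy(state, rng)                 # the agent acts on its current state
        state, reward, done = step(state, action)   # the world answers with r_t and s_{t+1}
        trajectory += [action, reward, state]
        if done:
            break
    return trajectory

def always_right(state, rng):
    """A trivial policy that ignores the state and always moves right."""
    return "R"

tau = rollout(always_right, start_state=0, max_steps=10, rng=random.Random(0))
# tau == [0, "R", 0.0, 1, "R", 0.0, 2, "R", 1.0, 3]
```

The list ends with a state rather than a reward, matching the shape of \(\tau\) above.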
Policy, policy gradient and basic policy gradient algorithm
A policy of an RL agent, often denoted as \(\pi\), decides which action to take. A policy can be deterministic or stochastic: under a deterministic policy, a state is mapped to a specific action, while under a stochastic policy, the agent obtains the probability of taking a specific action given the state. Mathematically, an agent following a deterministic policy \(\pi\) gets \(\pi(s_t) = a_t\), while it gets \(\pi(a_t \mid s_t) = p\), where \(p \in [0, 1]\), under a stochastic policy. A policy is often parameterized, meaning that learning to optimize a policy involves updating the parameters of a model, which could be a deep learning model, for instance. In this post, we denote such parameters as \(\theta\).
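The difference between the two kinds of policy can be sketched in a few lines. This assumes a tabular policy whose parameters \(\theta\) are one row of logits per state, turned into probabilities by a softmax; the action names, the state `(1, 1)`, and the logit values are illustrative choices, not something prescribed by the text.

```python
import math
import random

ACTIONS = ["R", "L", "U", "D"]  # hypothetical action set for illustration

def softmax(logits):
    """Turn arbitrary real-valued scores into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def stochastic_policy(theta, state):
    """pi_theta(. | s): probabilities over all actions in the given state."""
    return softmax(theta[state])

def deterministic_policy(theta, state):
    """pi(s) = a: always pick the highest-scoring action."""
    logits = theta[state]
    return ACTIONS[logits.index(max(logits))]

def sample_action(theta, state, rng):
    """Draw a_t ~ pi_theta(. | s_t)."""
    probs = stochastic_policy(theta, state)
    return rng.choices(ACTIONS, weights=probs, k=1)[0]

# theta maps a state to its logits; the numbers are made up.
theta = {(1, 1): [2.0, 0.0, 0.0, 0.0]}
probs = stochastic_policy(theta, (1, 1))
greedy = deterministic_policy(theta, (1, 1))
sampled = sample_action(theta, (1, 1), random.Random(0))
```

Learning then amounts to changing the numbers stored in `theta`, which is what "updating the parameters" means for this toy parameterization.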
It is also helpful to make a distinction between on-policy and off-policy RL. On-policy learning updates the policy using data generated by that same policy. For off-policy learning, the data need not have been generated by the policy being updated. The trade-off is clear: on-policy learning could be more effective, since there is an organic relationship between the data used for training and the policy itself, but it is less sample-efficient; off-policy learning is advantageous in terms of sample efficiency, but the mismatch between the data and the current policy can make training less stable.
Consider a trajectory \(\tau = (s_1, a_1, r_1, \cdots, s_T)\), where \(T\) could potentially be \(\infty\) (to be a little handwavy), in which case we have an infinite trajectory. The training objective of RL is to maximize the expected cumulative reward over \(\tau\). There is some design choice concerning how we should define the cumulative reward: should we assign equal weight to \(r_t\) and \(r_i\) for \(i > t\), or should we prioritize \(r_t\)? We can represent both scenarios in a single equation by incorporating a discount factor \(\gamma \in (0, 1]\), which controls how much weight is given to future rewards. In the formula below, \(R(\tau)\) is the cumulative (discounted) return over the trajectory \(\tau\): \begin{align} R(\tau) = \sum_{i=1}^{T} \gamma^{i-1} r_i. \end{align} If \(\gamma = 1\), we have \(R(\tau) = r_1 + r_2 + \cdots + r_T\), with all \(r_i\)’s equally weighted. On the other hand, as \(\gamma \to 0\), we more heavily weigh the immediate rewards. (YOLO!) Instead of considering the entire trajectory, that is, \(1 \leq i \leq T\), we can consider a smaller interval \(t \leq i \leq T\) and define reward-to-go as follows, which tends to be more stable:
$$ \begin{align*} G_t = \sum_{i=t}^T \gamma^{i-t} r_i. \end{align*} $$
Using \(G_t\) instead of \(R(\tau)\) lowers the variance of the resulting gradient estimate while keeping the estimator unbiased. For now, we’ll stick to \(R(\tau)\). We’ll get a chance to revisit a relevant discussion in a section or two.
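The reward-to-go formula is easy to compute in code. Here is a minimal sketch, assuming the rewards of one finished episode are available as a plain list; the function name `rewards_to_go` and the sample rewards are made up for illustration.

```python
def rewards_to_go(rewards, gamma):
    """Return [G_1, ..., G_T] with G_t = sum_{i=t}^{T} gamma^(i-t) * r_i.

    Uses the backward recursion G_t = r_t + gamma * G_{t+1}, so the whole
    list costs O(T) instead of O(T^2).
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

rewards = [1.0, 0.0, 2.0]                          # made-up rewards r_1, r_2, r_3
undiscounted = rewards_to_go(rewards, gamma=1.0)   # [3.0, 2.0, 2.0]
discounted = rewards_to_go(rewards, gamma=0.5)     # [1.5, 1.0, 2.0]
```

With \(\gamma = 1\) the first entry is just the plain sum of all rewards, and a smaller \(\gamma\) shrinks the contribution of later rewards to earlier entries.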
This naturally leads to the definition of the expected cumulative reward \(J(\pi_\theta) = \mathbb E_{\tau \sim \pi_\theta} [R(\tau)]\), the objective we try to maximize in training with RL, with the optimal parameters being \(\theta^*\). (Here, I use \(R(\tau)\), but it is also commonly replaced with \(G_t\). In machine learning, \(J(\theta)\) is, as far as I’m aware, a common notation for the cost function of a model parameterized by \(\theta\).) To update \(\theta\) under this formulation, we can use standard gradient ascent, \(\theta \leftarrow \theta + \alpha \nabla_\theta J(\pi_\theta)\), where \(\alpha\) is the learning rate. The subscript under \(\mathbb E\) also merits a closer look: our expectation is over the trajectories sampled from \(\pi_\theta\).
We are now ready to formally define the policy gradient, but in fact, we have already introduced it in the previous paragraph. The policy gradient is the gradient of the expected cumulative reward, \begin{align} \nabla_\theta J(\pi_\theta) = \nabla_\theta \mathbb E_{\tau \sim \pi_\theta} [ R(\tau) ], \end{align} and policy gradient algorithms are algorithms used to optimize our policy.
Consider an RL agent exploring an infinitely expanding 1-indexed 2D grid, with traps (negative rewards) and treasures (positive rewards) whose locations are fixed. Each cell of the grid is identified as \((i, j)\) where \(i\) and \(j\) are positive integers. The state space \(\mathcal S\) is \(\mathbb Z_{> 0} \times \mathbb Z_{> 0}\) and the action space is \(\{\texttt{R}, \texttt{L}, \texttt{U}, \texttt{D}\}\), where \(\texttt{R}\) takes the agent from \((i, j)\) to \((i+1, j)\), \(\texttt{U}\) takes the agent from \((i, j)\) to \((i, j-1)\), and so forth. The agent’s current state is \((r, c)\), and there is a trap in the cell to the right of the agent, which can deplete the agent’s energy immediately, programmed as \(r_t = \texttt{float('-inf')}\).
If we have set \( \pi_\theta(\texttt{R} \mid (r, c)) = \pi_\theta(\texttt{L} \mid (r, c)) = \pi_\theta(\texttt{U} \mid (r, c)) = \pi_\theta(\texttt{D} \mid (r, c)) = 0.25 \) as the prior, the agent initially (and unfortunately) has a one-quarter probability of falling into the trap. Since \(\texttt{-inf}\) is the dead-end in terms of the reward, as part of the policy optimization, we want the agent to learn to avoid selecting \(\texttt{R}\) when it’s located in \((r, c)\). In other words, we want \(\pi_\theta (\texttt{R} \mid (r, c))\) to be \(0\).
Note that we have assumed a live exploration scenario, but the discussion is essentially identical even if we consider maximizing the expected cumulative reward given trajectories sampled from \(\pi_\theta\).
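Here is a minimal sketch of how gradient ascent pushes \(\pi_\theta(\texttt{R} \mid (r, c))\) toward zero in this example. It makes several simplifying assumptions that are not in the text: only the single state \((r, c)\) is modeled, the policy is a softmax over four logits, the other three actions give a reward of 0, and a large finite penalty of -100 stands in for the infinite one (an infinite reward would make the gradient undefined). The gradient used is the exact gradient of the expected one-step reward, \(\pi_k (r_k - \sum_j \pi_j r_j)\) for logit \(k\), not a sampled estimate.

```python
import math

ACTIONS = ["R", "L", "U", "D"]
# Assumption: a large finite penalty replaces float('-inf') so gradients stay finite.
REWARDS = {"R": -100.0, "L": 0.0, "U": 0.0, "D": 0.0}

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_reward_gradient(logits):
    """Exact gradient of sum_a pi(a) * r(a) with respect to the softmax logits."""
    probs = softmax(logits)
    mean_reward = sum(p * REWARDS[a] for p, a in zip(probs, ACTIONS))
    return [p * (REWARDS[a] - mean_reward) for p, a in zip(probs, ACTIONS)]

def train(steps=200, lr=0.1):
    logits = [0.0, 0.0, 0.0, 0.0]  # uniform start: every action has probability 0.25
    for _ in range(steps):
        grad = expected_reward_gradient(logits)
        logits = [x + lr * g for x, g in zip(logits, grad)]  # gradient ascent step
    return softmax(logits)

initial = softmax([0.0, 0.0, 0.0, 0.0])
final = train()
# final[0] is the learned probability of stepping into the trap.
```

With a softmax the probability never reaches exactly 0; it only gets arbitrarily close as training continues.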
REINFORCE and baseline
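Let’s derive the actual expression that we can implement to optimize \(\theta\). \((2)\) is mathematically simple, but its expectation is over \(\tau \sim \pi_\theta\), a distribution that itself depends on \(\theta\), so we cannot simply differentiate inside the expectation. The derivation below rewrites it into a form we can estimate from samples. We can call the most basic version of our algorithm to optimize our policy, well, the vanilla policy gradient algorithm, but it actually has a better name: REINFORCE.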
$$ \begin{alignat*}{2} \nabla_\theta \mathbb E_{\tau \sim \pi_\theta} [R(\tau)] &= \nabla_\theta \int_{\tau} p(\tau \mid \theta) R(\tau) d\tau \\ &= \int_\tau \nabla_\theta \, p(\tau \mid \theta) R(\tau) d\tau \\ &= \int_\tau p(\tau \mid \theta) \nabla_\theta \log p(\tau \mid \theta) R(\tau) d\tau & \quad (a) \\ &= \mathbb E_{\tau \sim \pi_\theta} \left[ \nabla_\theta \log p(\tau \mid \theta) R(\tau) \right] \\ &= \mathbb E_{\tau \sim \pi_\theta} \left[ \nabla_\theta \log \left( \mu(s_1) \prod_{i=1}^T \pi_\theta (a_i \mid s_i) \, p(s_{i+1} \mid s_i, a_i) \right) R(\tau) \right] \\ &= \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) R(\tau) \right] & (b) \\ &\approx \frac{1}{N} \sum_{j=1}^{N} \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_{ji} \mid s_{ji}) R(\tau_j), & (c) \end{alignat*} $$where \(a_{ji}\) and \(s_{ji}\) denote \(a_i\) and \(s_i\) from \(\tau_j\). Here are some more details:
- \((a)\) The log-derivative trick, \(\nabla_\theta \, p = p \, \nabla_\theta \log p\), is used here, which I’ve also written about in my VAE post.
- \((b)\) When \(\log\) expands the product into a sum, \(\pi_\theta (\cdot)\) is the only term that depends on \(\theta\), so the gradient of every other term vanishes.
- \((c)\) This is an unbiased estimator.
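Estimator \((c)\) can be sketched on a toy problem where the exact gradient is also computable, so the two can be compared. The setup is an assumption made for illustration: episodes last a single step (\(T = 1\)), there is one state with four actions, the rewards are made up, and the policy is a softmax over logits, for which \(\nabla_\theta \log \pi_\theta(a)\) equals the one-hot vector of \(a\) minus the probability vector.

```python
import math
import random

REWARDS = [1.0, 0.0, 0.0, -1.0]  # made-up reward for each of four actions

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_log_prob(probs, action):
    """Score function of a softmax policy: onehot(action) - probs."""
    return [(1.0 if k == action else 0.0) - p for k, p in enumerate(probs)]

def reinforce_gradient(logits, num_samples, rng):
    """Estimator (c): average of grad log pi(a_j) * R(tau_j) over sampled episodes."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(num_samples):
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        score = grad_log_prob(probs, action)
        for k in range(len(grad)):
            grad[k] += score[k] * REWARDS[action] / num_samples
    return grad

def exact_gradient(logits):
    """Closed form for this toy problem: sum_a pi(a) * grad log pi(a) * r(a)."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for action, p in enumerate(probs):
        score = grad_log_prob(probs, action)
        for k in range(len(grad)):
            grad[k] += p * score[k] * REWARDS[action]
    return grad

logits = [0.0, 0.0, 0.0, 0.0]
estimate = reinforce_gradient(logits, num_samples=20000, rng=random.Random(0))
exact = exact_gradient(logits)  # approximately [0.25, 0.0, 0.0, -0.25]
```

The sampled estimate only needs actions and rewards, never the environment’s transition probabilities, which is the practical payoff of step \((b)\).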
So far so good, but REINFORCE has some significant shortcomings:
- In updating the parameters based on this gradient, we are using the reward of the entire trajectory, as evidenced by the \(R(\tau_j)\) at the end. Thus naturally, it has a large variance, which could render the training unstable. Takeaway: We need to modify the algorithm to reduce variance.
- If the trajectory is too lengthy, this will lead to relatively few gradient updates, hence REINFORCE is sample-inefficient. Takeaway: We need a more sample-efficient approach.
How can we reduce variance? A good starting point is replacing the reward \(R(\tau)\) only computed at the end of the trajectory with reward-to-go \(G_t\):
$$ \begin{align*} \nabla_\theta J(\pi_\theta) = \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, G_i \right]. \end{align*} $$
Notice that the root cause of a high variance is the rewards fluctuating wildly. Replacing \(R(\tau)\) with the reward-to-go helps, but what if we can go one step further? (Baselines often do a great job as a variance reduction technique. Nevertheless, implementing the optimal baseline is not straightforward. Moreover, as Chung et al. (2021) point out, introducing baselines may also affect learning dynamics in unexpected ways.) Take a look at the following modified expression where we’ve replaced \(G_i\) with \(G_i - b(s_i)\):
$$ \begin{align*} \nabla_\theta J(\pi_\theta) = \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) \left( G_i - b(s_i) \right) \right]. \end{align*} $$
The term \(b(s_i)\) is called the baseline. Unlike REINFORCE, where the gradient was scaled by the absolute reward, this version changes the scaling factor to a number representing how much better or worse the outcome for a certain state was, relative to a baseline. Introducing \(b(s_i)\) can center the signal and thus reduce variance.
What is a reasonable baseline?
A commonly used baseline is the on-policy value function, which gives the expected return given a (static) policy starting from a state:
$$ \begin{align*} V^{\pi_\theta}(s) = \mathbb E_{\tau \sim \pi_\theta} \left[ G_t \mid s_t = s \right]. \end{align*} $$
If starting from \(s_1\), \(V^{\pi_\theta} (s_1) = \mathbb E_{\tau \sim \pi_\theta} \left[ R(\tau) \right] \) by definition. Compare \(b(s_i)\) and \(V^{\pi_\theta} (s_i)\), and their similarity is apparent.
But where do we obtain the value function?
Using \(V^{\pi_\theta}\) as the baseline only makes sense when we actually know the expected return starting from an arbitrary state, which is clearly not provided to us a priori. Well, if we don’t know what it is, then we can learn it! In practice, we train a separate neural network parameterized by \(\phi\) to estimate the value function. This model that learns to predict \(V^{\pi_\theta}\) is often called a critic. Recall that the baseline should be a good representation of the actual reward-to-go, which is obtained from actual samples. The critic’s training objective for \(V_\phi\) is thus to minimize the difference between \(G_t\) and \(V_\phi(s_t)\), for example with a squared-error loss. Note what this implies: the smaller the value of \(G_t - b(\cdot)\) is, the smaller the policy gradient signal we have. If you hear of the “actor-critic” method, it refers precisely to this family of algorithms, where the actor (agent) learns to optimize its policy using the \(V\)-signals provided by the critic.
Wouldn’t it shrink the gradient?
This is the beautiful part: the answer is no. Baselines don’t affect the expected value of the gradient.
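Before the derivation, here is a small numerical illustration of both claims: the mean of the gradient estimate stays put, while its variance drops. The setup is assumed for illustration only: a single state, one-step episodes (so the reward-to-go is just the immediate reward), a uniform softmax policy over four actions, made-up rewards with a large common offset, and the expected reward under the policy used as the baseline. Because the action set is tiny, the mean and variance are computed exactly by enumerating actions rather than by sampling.

```python
import math

REWARDS = [11.0, 10.0, 10.0, 9.0]  # made-up rewards with a large common offset

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def gradient_sample(probs, action, baseline):
    """One-sample gradient for logit 0: (onehot(action) - probs)[0] * (reward - baseline)."""
    score = (1.0 if action == 0 else 0.0) - probs[0]
    return score * (REWARDS[action] - baseline)

def mean_and_variance(probs, baseline):
    """Exact mean and variance of the one-sample estimator, by enumerating actions."""
    samples = [gradient_sample(probs, a, baseline) for a in range(len(probs))]
    mean = sum(p * g for p, g in zip(probs, samples))
    var = sum(p * (g - mean) ** 2 for p, g in zip(probs, samples))
    return mean, var

probs = softmax([0.0, 0.0, 0.0, 0.0])
value = sum(p * r for p, r in zip(probs, REWARDS))  # expected reward, used as the baseline
mean_plain, var_plain = mean_and_variance(probs, baseline=0.0)
mean_base, var_base = mean_and_variance(probs, baseline=value)
# The two means agree, while var_base is far smaller than var_plain.
```

In this toy case the common offset of 10 carries no information about which action is better, yet without a baseline it dominates the size of every sampled gradient.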
From our formula for \(\nabla_\theta J(\pi_\theta)\) we have above, let’s focus on the second term that contains the baseline:
$$ \begin{align*} \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, b(s_i) \right] = \sum_{i=1}^T \mathbb E_{\tau \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, b(s_i) \right]. \end{align*} $$
For simplicity, let \(X_i \triangleq \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, b(s_i) \). If we fix a state \(s_i\) and take the expectation over \(a_i \sim \pi_\theta ( \,\cdot \mid s_i) \), the law of total expectation gives
$$ \begin{align*} \mathbb E_{\tau \sim \pi_\theta} \left[ X_i \right] &= \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ \mathbb E_{a_i \sim \pi_\theta (\cdot \mid s_i)} \left[ X_i \mid s_i \right] \right], \end{align*} $$
where \(d_i^{\pi_\theta}(s)\) denotes the marginal distribution over \(s_i\) in the trajectory under policy \(\pi_\theta\). Continuing with the derivation,
$$ \begin{alignat*}{2} \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ \mathbb E_{a_i \sim \pi_\theta (\cdot \mid s_i)} \left[ \nabla_\theta \log \pi_\theta (a_i \mid s_i) b(s_i)\right] \right] &= \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ \int_{\mathcal A} \pi_\theta(a_i \mid s_i) \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, b(s_i) \, da_i \right] \\ &= \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ b(s_i) \int_{\mathcal A} \nabla_\theta \pi_\theta(a_i \mid s_i) \, da_i \right] & \quad (a) \\ &= \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ b(s_i) \, \nabla_\theta \int_{\mathcal A} \pi_\theta(a_i \mid s_i) \, da_i \right] \\ &= \mathbb E_{s_i \sim d_i^{\pi_\theta} } \left[ b(s_i) \, \nabla_\theta \, 1 \right] & \quad (b) \\ &= 0 \end{alignat*} $$
Some details:
- \((a)\) Given \(s_i\), \(b(s_i)\) does not depend on \(a_i\), so it can be pulled out of the integral over actions (but not out of the expectation over \(s_i\)). We’ve also applied the log-derivative trick.
- \((b)\) \(\nabla_\theta\) is a linear operator, so it can be swapped with the integral. Assuming a continuous action space \(\mathcal A\), \(\pi_\theta(\cdot \mid s_i)\) is a probability density, so the integral evaluates to 1, and \(\nabla_\theta 1 = 0\).
Therefore, given a fixed state, the baseline does not depend on the sampled action, and the baseline term averages to zero. As a result,
$$ \begin{align*} \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) \left( G_i - b(s_i) \right) \right] = \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{i=1}^T \nabla_\theta \log \pi_\theta (a_i \mid s_i) \, G_i \right]. \end{align*} $$
Some notational sidenotes and conceptual clarifications
So far, we’ve laid out all the necessary details for RL. When I was first learning about these concepts, I was fairly confused about the notation, some of which even appeared to be interchangeable at times. So I’d like to dedicate a short section to clarifying the notation.
Let’s start again with the policy gradient \(\nabla_\theta J(\theta)\). In the previous section, we introduced two versions, with the latter often having a benefit of reduced variance:
$$ \begin{align*} \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{t=1}^T \nabla_\theta \log \pi_\theta (a_t \mid s_t) \, R(\tau) \right] \quad \text{and} \quad \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{t=1}^T \nabla_\theta \log \pi_\theta (a_t \mid s_t) \, G_t \right]. \end{align*} $$
So it turns out the only difference really lies in the factor that multiplies the \(\nabla_\theta \log \pi_\theta\) term. As beautifully explained originally here, there, hither, and yon, we can define a generalized version of the policy gradient as:
$$ \begin{align*} \nabla_\theta J(\theta) = \mathbb E_{\tau \sim \pi_\theta} \left[ \sum_{t=1}^T \nabla_\theta \log \pi_\theta (a_t \mid s_t) \, \Psi_t \right] \end{align*} $$
where \(\Psi_t\) can be any of the following six forms:
- \(R(\tau)\): total reward of the trajectory
- \(G_t\): total reward of the trajectory starting from \(s_t\), introduced as “reward-to-go”
- \(G_t - b(s_t)\): reward-to-go with baseline
- \(Q^{\pi_\theta} (s_t, a_t)\): state-action value function
- \(A^{\pi_\theta} (s_t, a_t) \triangleq Q^{\pi_\theta} (s_t, a_t) - V^{\pi_\theta} (s_t)\): advantage function
- \(\delta_t \triangleq r_t + \gamma V^{\pi_\theta} (s_{t+1}) - V^{\pi_\theta} (s_t)\): temporal difference (TD) residual, where \(\gamma \in (0, 1]\) is the discount factor
More terms have been introduced but (I promise) we will construct a nice set of pairwise conceptual bridges very soon.
- \(Q\) is basically \(V\) after taking an action: The state-action value function is not a novel concept. While \(V^{\pi_\theta} (s_t)\) denotes the expected cumulative return starting at state \(s_t\) and following policy \(\pi\) (which is parameterized by \(\theta\)), \(Q^{\pi_\theta} (s_t, a_t)\) denotes the same quantity but crucially after taking action \(a_t\) at state \(s_t\), and then following policy \(\pi\). Formally, \(Q^{\pi_\theta} (s_t, a_t) = \mathbb E \left[ G_t \mid s = s_t, a = a_t \right] \).
- \(G_t\) and \(b(s_t)\) are Monte Carlo estimates: \(G_t\) and \(Q^{\pi_\theta} (s_t, a_t)\) look like twins, and so do \(b(s_t)\) and \(V^{\pi_\theta} (s_t)\). This is because \(G_t\) is a Monte Carlo estimate of \(Q^{\pi_\theta}(s_t, a_t)\) computed from a single sampled trajectory, and \(b(s_t)\), when fit to sampled returns, is an estimate of \(V^{\pi_\theta}(s_t)\). This provides a justification for why we have \(G_t - b(s_t)\) as a possible form of \(\Psi_t\).
- \(A^{\pi_\theta}\) measures how good an action is: The advantage function is defined as \(Q^{\pi_\theta} (s_t, a_t) - V^{\pi_\theta} (s_t)\). Recall that our immediate TODO in a gradient ascent at step \(t\) is to help our agent learn how to take the best action \(a_t\). The advantage function tells us how “advantageous” \(a_t\) is: if it is positive, it means that \(a_t\) was a promising action to take (since its expected return is larger than the expected return averaged over all actions the policy could take). The Monte Carlo estimate for \(A^{\pi_\theta}\) is \(\hat{A}_t = G_t - b(s_t)\).
- TD is a method that uses bootstrapping and updates more frequently: Previously, we replaced \(R(\tau)\) with \(G_t\) as a variance reduction technique. \(G_t\) still relies on the trajectory from time \(t\) until the terminal step. But what if the trajectory never ends (in other words, the task is continuing rather than episodic)? The TD method is an effective solution here: it updates after each step, rather than waiting until the end. The change is simple: instead of using \(G_t\) or \(Q^{\pi_\theta}(s_t, a_t)\), we use \(r_t + \gamma V^{\pi_\theta}(s_{t+1})\). Subtracting \(V^{\pi_\theta}(s_t)\) then gives a form that is essentially the advantage, commonly computed with \(V_\phi\) estimated by the critic. It should be additionally noted that the TD residual is biased when an imperfect estimate such as \(V_\phi\) stands in for \(V^{\pi_\theta}\), and that there is no formal proof that the TD method converges faster than Monte Carlo methods. (Empirically, it often does converge faster, though.) (There is an interesting section on Wikipedia on the relationship between TD learning and neuroscience.)
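To tie the forms of \(\Psi_t\) together, here is a minimal sketch that computes the reward-to-go, the Monte Carlo advantage estimate, and the TD residual for one short trajectory. The rewards and the value estimates are made up, the value list is assumed to come from some critic, and the value of the terminal state is assumed to be 0.

```python
def rewards_to_go(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over the episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def monte_carlo_advantages(rewards, values, gamma):
    """A_hat_t = G_t - V(s_t), using the sampled reward-to-go."""
    returns = rewards_to_go(rewards, gamma)
    return [g - v for g, v in zip(returns, values)]

def td_residuals(rewards, values, gamma):
    """delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).

    `values` has one more entry than `rewards`: the last entry is the value
    of the state reached after the final reward (0 for a terminal state).
    """
    return [rewards[t] + gamma * values[t + 1] - values[t]
            for t in range(len(rewards))]

rewards = [1.0, 0.0, 2.0]        # made-up rewards
values = [2.0, 1.5, 1.0, 0.0]    # hypothetical critic outputs; terminal value is 0
gamma = 1.0

mc_adv = monte_carlo_advantages(rewards, values[:-1], gamma)  # [1.0, 0.5, 1.0]
td = td_residuals(rewards, values, gamma)                     # [0.5, -0.5, 1.0]
```

Each TD residual only needs one reward and two value estimates, so it is available right after a step, whereas the Monte Carlo advantage has to wait for the episode to end. In this example the discounted sum of the remaining TD residuals from step \(t\) onward equals the Monte Carlo advantage at \(t\), which holds here because the terminal value is 0.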
So that’s it! We are now ready to dive into other policy gradient algorithms.