Adapting large language models (LLMs) to specific tasks often involves prompting, Retrieval-Augmented Generation (RAG), or agentic systems. Prompting suits quick, general tasks but falters in complex reasoning or specialization. RAG excels with external knowledge but struggles to teach new skills or control output style. Agentic systems fit dynamic goals yet can overcomplicate simpler needs. Group Relative Policy Optimization (GRPO), a DeepSeek reinforcement learning method [1], is ideal when deep domain expertise, precise style and tone control, specific output formatting, or debiasing are required—particularly for reasoning-intensive tasks without clear answers, as shown in DeepSeekMath [2]. This paper offers a clear, comprehensive guide to GRPO, blending theory, math, and practical steps. Where existing resources scatter or omit details, we provide a unified, pedagogical resource to unlock GRPO’s potential for fine-tuning LLMs effectively.
The paper is organized into four main sections to provide a comprehensive understanding of GRPO from theory to practice. The first section offers a theoretical deep dive, detailing the algorithm’s mechanics with rigor and intuitive explanations. The second section serves as a practical tutorial, guiding readers through a quick application of GRPO using the TRL library with simplified steps and examples. The third section presents a simplified, didactic implementation of GRPO, designed for clarity and educational purposes, using a small model and basic prompts. Finally, the fourth section explores an optimized, industrial-grade implementation from the TRL library, mapping theoretical steps to production-ready code.
Overview of the GRPO Algorithm
Group Relative Policy Optimization (GRPO) fine-tunes a language model by iteratively improving its policy through group-based reward comparisons. The algorithm proceeds in five steps: prepare a batch of training queries, sample a group of outputs per query, score the outputs and compute group-relative advantages, compute a clipped surrogate loss with a KL penalty, and backpropagate to update the policy.
These steps are illustrated in Figure 1, which applies them to fine-tune an LLM on mathematical reasoning. Let us now delve into the details of each step.
Step 1: Prepare a Batch of Training Queries
Take a batch of training queries $\{q_1, q_2, \dots, q_B\}$, where $B$ is the batch size. These are questions or prompts the model will respond to.
Step 2: Sample $G$ Outputs for a Single Query
For simplicity, consider a single query $q$ from the batch. Using the current policy model with parameters $\theta_{old}$ (denoted $\pi_{\theta_{old}}$), generate $G$ different outputs $\{o_1, o_2, \dots, o_G\}$. Each output $o_i$ is a sequence of tokens:

$o_i = (o_{i,1}, o_{i,2}, \dots, o_{i,|o_i|}),$

where $|o_i|$ is the length of the sequence.
Why G Outputs? Sampling multiple outputs allows GRPO to compare them relative to each other, forming a group-based baseline for rewards.
Step 3: Calculate Rewards and Advantages
Each output $o_i$ is first assigned a scalar reward $r_i$. To make this concrete, DeepSeek uses a rule-based strategy tailored to each task, such as math or coding, computing the reward as a weighted combination:
$r_i = \alpha \cdot accuracy\_score + \beta \cdot format\_score,$
where $\alpha$ and $\beta$ are task-specific weights balancing correctness and structure.
For the accuracy score, math tasks use regular expressions to extract the final answer and compare it to the ground truth (1 if correct, 0 otherwise). Coding tasks run the code in a sandbox and assign a score based on how many test cases pass. The format score, on the other hand, checks whether the output follows the expected structure—such as including reasoning within specific tags like `<think>...</think>`—and is typically binary (1 if well-structured, 0 if not).
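To make this scoring concrete, here is a minimal sketch of such a rule-based reward. The `Answer:` extraction pattern, the `<think>` format check, and the weights `alpha` and `beta` are illustrative assumptions, not DeepSeek's actual scoring code.

```python
import re

def rule_based_reward(output: str, ground_truth: str, alpha: float = 1.0, beta: float = 0.5) -> float:
    # Format score: 1 if the reasoning is wrapped in <think>...</think>, 0 otherwise.
    format_score = 1.0 if re.search(r"<think>.*?</think>", output, re.DOTALL) else 0.0
    # Accuracy score: extract the final number after "Answer:" and compare it to the ground truth.
    match = re.search(r"Answer:\s*(-?\d+(?:\.\d+)?)", output)
    accuracy_score = 1.0 if match and match.group(1) == ground_truth else 0.0
    return alpha * accuracy_score + beta * format_score

print(rule_based_reward("<think>2 + 3 = 5</think> Answer: 5", "5"))  # 1.5
print(rule_based_reward("Answer: 7", "5"))                           # 0.0
```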
Each reward is then normalized against the other rewards of the same group to obtain the advantage:

$A_i = \frac{r_i - \bar{r}}{\sigma_r},$

Terms:

- $r_i$: reward of output $o_i$.
- $\bar{r}$, $\sigma_r$: mean and standard deviation of the group's rewards $\{r_1, \dots, r_G\}$ (a small constant is typically added to $\sigma_r$ for numerical stability).
The idea is that by normalizing the reward relative to the group, indicating how much better or worse $o_i$ is compared to the average, the model learns to favor responses with $A_i > 0$ and suppress those with $A_i < 0$. For instance, if $A_i = 0.94$, the model increases the likelihood of generating that (correct) response.
Note that in standard GRPO (outcome supervision), the advantage $A_i$ is the same for all tokens $o_{i,t}$ in output $o_i$. So, $A_{i,t} = A_i$ for all $t = 1, 2, \dots, |o_i|$. This is because the reward $r_i$ is given for the entire output $o_i$, not per token or step.
Exception: In process supervision (not standard GRPO), rewards are given per reasoning step, and advantages could vary per token or segment. But for this explanation, we assume outcome supervision, so $A_{i,t} = A_i$.
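As a small worked example of Step 3, the sketch below computes group-normalized advantages for hypothetical rewards of $G = 4$ outputs. It uses the population standard deviation; implementations may instead use the sample standard deviation, as in the didactic code later in this paper.

```python
# Hypothetical rewards for G = 4 outputs of one query (illustrative values).
rewards = [1.0, 0.0, 0.0, 1.0]
mean_r = sum(rewards) / len(rewards)                                      # 0.5
std_r = (sum((r - mean_r) ** 2 for r in rewards) / len(rewards)) ** 0.5   # 0.5

advantages = [(r - mean_r) / (std_r + 1e-8) for r in rewards]
print(advantages)  # approximately [1.0, -1.0, -1.0, 1.0]: correct outputs reinforced, incorrect ones suppressed
```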
Step 4: Compute the Surrogate Loss
Probability Ratio: For each token $o_{i,t}$ in output $o_i$, compute the ratio of probabilities between the current policy $\pi_{\theta}$ and the old policy $\pi_{\theta_{old}}$: $ratio_{i,t} = \frac{\pi_{\theta}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})}$
Terms:

- $\pi_{\theta}(o_{i,t} \mid q, o_{i,<t})$: probability the current policy assigns to token $o_{i,t}$ given the query $q$ and the preceding tokens $o_{i,<t}$.
- $\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})$: the same probability under the old policy that generated the outputs.
The idea is to measure how much the policy has changed for that token.
Clipped Objective: Define the clipped term: $g(\epsilon, A_i) = \text{clip}(ratio_{i,t}, 1 - \epsilon, 1 + \epsilon) \cdot A_i$

Terms:

- $\epsilon$: clipping hyperparameter (e.g., 0.2) bounding how far the ratio may move away from 1.
- $\text{clip}(x, a, b)$: restricts $x$ to the interval $[a, b]$.
The idea is to limit large policy updates for stability.
Loss per Token: For each token $o_{i,t}$: $L_{i,t} = \min \left( ratio_{i,t} \cdot A_i, \; g(\epsilon, A_i) \right)$

The idea is to take the minimum of the unclipped and clipped terms to update the policy conservatively: the clipping restricts the policy update ratio to $[1 - \epsilon, 1 + \epsilon]$ to avoid large shifts from the old policy, which in particular limits overconfident updates.
Example with $\epsilon = 0.2$: if $\pi_{\theta}(o_i|q) = 0.9$, $\pi_{old}(o_i|q) = 0.5$, then ratio $= 1.8 \rightarrow$ clip to 1.2. If new policy gives 0.2, then $0.2 / 0.5 = 0.4 \rightarrow$ clip to 0.8.
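Verifying this example in a couple of lines of PyTorch, with the illustrative choice $A_i = 1.0$:

```python
import torch

eps = 0.2
ratios = torch.tensor([1.8, 0.4])                # 0.9 / 0.5 and 0.2 / 0.5
clipped = torch.clamp(ratios, 1 - eps, 1 + eps)  # tensor([1.2000, 0.8000])

A = 1.0  # assumed positive advantage for illustration
loss_term = torch.min(ratios * A, clipped * A)   # tensor([1.2000, 0.4000])
print(clipped, loss_term)
```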
Total Surrogate Loss: Average over all tokens and outputs: $L_{GRPO}(\theta) = \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} L_{i,t}$

Terms:

- $G$: number of outputs sampled per query.
- $|o_i|$: number of tokens in output $o_i$.
KL Divergence Penalty: Add a penalty to prevent large deviations from a reference policy $\pi_{ref}$ (e.g., the initial policy): $L_{total}(\theta) = L_{GRPO}(\theta) - \beta D_{KL}[\pi_\theta || \pi_{ref}]$

Terms:

- $\pi_{ref}$: the reference policy, typically the model before GRPO fine-tuning.
- $\beta$: coefficient controlling the strength of the KL penalty.
- $D_{KL}[\pi_\theta || \pi_{ref}]$: KL divergence between the current policy and the reference policy.
The idea is to ensure stability by keeping $\pi_\theta$ close to $\pi_{ref}$. A KL divergence penalty keeps the model’s outputs near the original distribution, preventing extreme shifts while still allowing controlled exploration and refinement.
The $\beta$ parameter controls the strength of the KL divergence penalty. A higher $\beta$ keeps the policy close to the reference, ensuring stability but slowing exploration. A lower $\beta$ allows faster adaptation and more deviation, but risks instability or reward hacking. The original DeepSeekMath paper used $\beta= 0.04$.
Step 5: Backpropagate and Update the Policy
Update: Use an optimizer (e.g., Adam) to adjust $\theta$: $\theta \leftarrow \theta - \eta \nabla_\theta L_{total}(\theta)$
where $\eta$ is the learning rate (e.g., $10^{-5}$).
The idea is to minimize the loss, effectively maximizing the expected reward by adjusting token probabilities.
Summary of Key Formulas

- Advantage: $A_i = \frac{r_i - \bar{r}}{\sigma_r}$
- Probability ratio: $ratio_{i,t} = \frac{\pi_{\theta}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})}$
- Per-token loss: $L_{i,t} = \min \left( ratio_{i,t} \cdot A_i, \; \text{clip}(ratio_{i,t}, 1 - \epsilon, 1 + \epsilon) \cdot A_i \right)$
- Total loss: $L_{total}(\theta) = \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} L_{i,t} - \beta D_{KL}[\pi_\theta || \pi_{ref}]$
- Update: $\theta \leftarrow \theta - \eta \nabla_\theta L_{total}(\theta)$
Limitations & Challenges of GRPO
The following passage is taken directly from the HuggingFace Reasoning Course [3]:
Here’s what you need to know to get started as soon as possible with your own GRPO use case!
Alongside the theoretical concepts outlined earlier, we present a pedagogically focused version of the excellent practical tutorial available online [4], which utilizes the TRL library's implementation of the GRPO algorithm [5]. This section emphasizes the core components to help you hit the ground running.
The tutorial uses the GSM8K dataset, a collection of math word problems designed to test reasoning skills. Each entry consists of a question and an answer, with the final numerical solution marked by `####` in the answer text.
To give a clearer picture, the first two entries of the dataset are word problems whose answer texts end with the markers `#### 72` and `#### 10`, respectively.
The first step is to extract the final answer (e.g., “72” or “10”) for evaluation. We define a function to do this:
def extract_hash_answer(text):
if "####" not in text:
return None
return text.split("####")[1].strip()
For instance, applying this to the first sample's answer yields "72".
Next, we define a system prompt to guide the model’s output format, encouraging it to show its reasoning and provide a solution within specific tags:
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"
system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
The dataset is then mapped to pair each question with this system prompt and the extracted answer, preparing it for GRPO training.
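For reference, that mapping step might look roughly like the sketch below. The conversational `prompt` format and the reuse of the `answer` column are assumptions based on the tutorial's description; the notebook's exact preprocessing may differ.

```python
from datasets import load_dataset

dataset = load_dataset("openai/gsm8k", "main", split="train")

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})
```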
Reward functions are the heart of GRPO, as they evaluate the quality of the model’s outputs and drive policy optimization.
The tutorial defines multiple reward functions, but we highlight two key examples here for clarity.
The first, `match_format_exactly`, awards points if the output adheres precisely to the expected structure, including both reasoning and solution sections:
def match_format_exactly(completions, **kwargs):
scores = []
for completion in completions:
score = 0
response = completion[0]["content"]
if match_format.search(response) is not None:
score += 3.0
scores.append(score)
return scores
Here, `match_format` is a regular expression ensuring the presence of all required tags in the correct order (defined earlier in the code, omitted here for brevity). An output like:
<start_working_out>Let’s think!<end_working_out><SOLUTION>42</SOLUTION>
would score 3.0, while a malformed response would score 0.
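Since the tutorial's exact `match_format` regex is omitted here, the following is a plausible definition (an assumption for illustration only); note that its first capture group holds the solution text, which `check_answer` below relies on.

```python
import re

match_format = re.compile(
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}",
    flags=re.DOTALL,
)

example = "<start_working_out>Let's think!<end_working_out><SOLUTION>42</SOLUTION>"
print(match_format.search(example).group(1))  # 42
```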
The second function, `check_answer`, evaluates the correctness of the solution by comparing the extracted answer to the ground truth:
def check_answer(prompts, completions, answer, **kwargs):
    scores = []
    for completion, true_answer in zip(completions, answer):
        score = 0
        response = completion[0]["content"]
        match = match_format.search(response)
        guess = match.group(1) if match else None
        if guess is None:
            scores.append(0)          # no extractable answer
            continue
        if guess == true_answer:
            score += 3.0              # exact match
        elif guess.strip() == true_answer.strip():
            score += 1.5              # match ignoring surrounding whitespace
        scores.append(score)
    return scores
This function awards 3.0 for an exact match (e.g., “72” vs. “72”), 1.5 for a match ignoring whitespace, and 0 otherwise. Additional reward functions (e.g., partial format matching or numerical closeness) enhance flexibility.
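The tutorial's additional reward functions are not reproduced here, but a numerical-closeness reward could look like the hypothetical sketch below; the name `check_numbers`, the 5% tolerance, and the partial scores are illustrative assumptions rather than the tutorial's code.

```python
def check_numbers(prompts, completions, answer, **kwargs):
    scores = []
    for completion, true_answer in zip(completions, answer):
        response = completion[0]["content"]
        match = match_format.search(response)
        try:
            guess = float(match.group(1).strip())
            truth = float(true_answer)
        except (AttributeError, ValueError):
            scores.append(0.0)        # no parseable number
            continue
        if guess == truth:
            scores.append(1.5)        # exact numerical match
        elif abs(guess - truth) <= 0.05 * abs(truth):
            scores.append(0.5)        # within 5% of the true answer
        else:
            scores.append(0.0)
    return scores
```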
Training is performed using the `GRPOTrainer` from the TRL library [5]. The Gemma3 model, enhanced with LoRA adapters via Unsloth for efficient fine-tuning, is trained on the prepared dataset with the defined reward functions. Key hyperparameters are set as follows:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
learning_rate=5e-6,
per_device_train_batch_size=1,
num_generations=4, # Number of outputs G per query
max_steps=50,
max_prompt_length=256,
max_completion_length=768,
output_dir="outputs",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[match_format_exactly, check_answer],
args=training_args,
train_dataset=dataset,
)
trainer.train()
Here, `num_generations=4` corresponds to $G$ in our theoretical explanation, generating four outputs per query to compute group-based advantages.
Some practical tips include:

- Adjust `per_device_train_batch_size` and `gradient_accumulation_steps` to fit GPU memory; enable `use_vllm=True` for faster generation if supported.
- Monitor the key training metrics: `reward` (average across completions), `reward_std` (variation within groups), and `kl` (divergence from the reference model).

This concludes the preparation and training procedure, where reward-guided optimization plays a central role in refining the model's ability to reason and answer accurately.
In this section, we present a simplified yet functional implementation of the GRPO (Group Relative Policy Optimization) algorithm, fully taken from the HuggingFace Reasoning Course [3] but rearranged for improved pedagogical clarity. It bridges the theoretical deep dive above with concrete code, using the small model `Qwen/Qwen2-Math-1.5B` and a basic math prompt for focus and accessibility. In the subsequent section, we explore the optimized, industrial-grade implementation of GRPO provided by HuggingFace's TRL library, illustrating how the same algorithm scales to production-ready use cases.
This stage corresponds to Step 1 and Step 2 of the theoretical overview. We load a pre-trained language model and generate multiple responses for a batch of prompts.
We use two prompts (shown in the code below). Here, the batch size is $B = 2$ (two queries), and we generate $G = 4$ responses per query, totaling 8 responses.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer
model_name = "Qwen/Qwen2-Math-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Input prompts (batch of 2 queries)
prompts = [
"Solve y = 2x + 1 for x = 2, y = ", # Correct answer: 5
"Solve y = 2x + 1 for x = 4, y = " # Correct answer: 9
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(device) # Shape: (2, prompt_len)
attention_mask = inputs["attention_mask"].to(device)
# Generate 4 responses per prompt (B=2, G=4, total 8 responses)
batch_size = len(prompts) # 2
num_generations = 4
outputs = model.generate(
input_ids=input_ids, # Shape: (2, prompt_len)
attention_mask=attention_mask,
max_new_tokens=1, # Single-token response
num_return_sequences=num_generations, # 4 per prompt
do_sample=True,
top_k=10,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
)
Comments:

- We load `Qwen/Qwen2-Math-1.5B` and its tokenizer, representing the current policy $\pi_{\theta_{old}}$ (Step 1). The model is set to evaluation mode and moved to the GPU if available.
- The prompts are tokenized into `input_ids` with shape $(2, prompt\_len)$, matching Step 1's batch of queries $\{q_1, q_2\}$.
- The `model.generate` call produces $G = 4$ responses per prompt. With `input_ids` of shape $(2, prompt\_len)$ and `num_return_sequences=4`, it generates $2 \times 4 = 8$ total responses (Step 2). `max_new_tokens=1` ensures single-token outputs (e.g., "5", "9"). Sampling parameters (`top_k=10`, `temperature=0.7`) ensure diversity.
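To connect generation to the reward stage that follows, the sampled tokens could be decoded and scored directly against the expected answers, as sketched below (an illustration; the next snippet instead hard-codes a concrete set of rewards). With `num_return_sequences=4`, the first four sequences correspond to the first prompt and the last four to the second.

```python
# Decode only the newly generated tokens (everything after the prompt).
completions = tokenizer.batch_decode(
    outputs.sequences[:, input_ids.shape[1]:], skip_special_tokens=True
)
expected = ["5"] * 4 + ["9"] * 4  # each prompt is answered G = 4 times
computed_rewards = torch.tensor(
    [1.0 if c.strip() == e else 0.0 for c, e in zip(completions, expected)]
)
```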
This stage implements Step 3: assigning rewards, computing group-wise statistics, and calculating advantages.
For the generated responses, we use a binary reward: $r_i = 1$ if the answer is correct, $0$ otherwise:
# Rewards for the 8 responses (flattened)
rewards = torch.tensor([1, 0, 0, 1, 0, 0, 1, 1], dtype=torch.float32) # Shape: (8,)
# Note: In practice, rewards are computed by comparing generated tokens to correct answers (5, 9)
# Group rewards: Shape (B, G) = (2, 4)
rewards_grouped = rewards.view(batch_size, num_generations)
# Mean per group: Shape (B,) = (2,)
mean_grouped_rewards = rewards_grouped.mean(dim=1)
# Std per group: Shape (B,) = (2,)
std_grouped_rewards = rewards_grouped.std(dim=1)
# Broadcast to match rewards: Shape (8,)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations)
# Advantages: Shape (8,)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
# Reshape to match logits: Shape (8, 1)
advantages = advantages.unsqueeze(1)
Explanation:

- `rewards_grouped` becomes shape $(2, 4)$: `[[1, 0, 0, 1], [0, 0, 1, 1]]`, one row per prompt. The group mean and standard deviation are computed per row, broadcast back to shape $(8,)$, and used to normalize each reward into its advantage.

This stage implements Step 4 (surrogate loss) and Step 5 (policy update), using advantages to refine the model.
import torch.nn.functional as F
# Assume log probs are available (Shape: (8, 1))
# In practice, computed by passing outputs through old and new models
per_token_logps = ... # Old policy log probs
new_per_token_logps = ... # New policy log probs
# Probability ratio: Shape (8, 1)
ratio = torch.exp(new_per_token_logps - per_token_logps)
# Clipping
eps = 0.2
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
pg_loss_max = torch.max(pg_losses1, pg_losses2)
# KL penalty: Shape (8, 1)
per_token_kl = F.kl_div(
F.log_softmax(new_per_token_logps, dim=-1),
F.softmax(per_token_logps, dim=-1),
reduction="none",
).sum(dim=-1, keepdim=True)
# Total loss
beta = 0.01
per_token_loss = pg_loss_max + beta * per_token_kl
total_loss = per_token_loss.mean()
# Update model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
Explanation:

- The probability ratio is computed as the exponential of the difference between new and old per-token log probabilities, matching $ratio_{i,t}$ from Step 4.
- With $\epsilon = 0.2$, both the unclipped and clipped terms are formed (negated, since the optimizer minimizes), and `torch.max` of the negated losses corresponds to taking the minimum of the surrogate objectives.
- A KL penalty weighted by $\beta = 0.01$ is added, and the mean over all tokens gives the total loss.
- Finally, Adam with learning rate $10^{-5}$ backpropagates the loss and updates the policy (Step 5).
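The two placeholders above (`per_token_logps` and `new_per_token_logps`) are left elided in the course material. Under the single-token setup of this section, one way they could be obtained is sketched below (an assumption, not the course's exact code): re-run the full sequences through a model and gather the log probability of each generated token.

```python
import torch

def get_per_token_logps(model, sequences, prompt_len):
    logits = model(sequences).logits             # (8, seq_len, vocab_size)
    # Logits at position t predict token t + 1, so shift by one position.
    logits = logits[:, prompt_len - 1:-1, :]     # (8, num_new_tokens, vocab_size)
    targets = sequences[:, prompt_len:]          # (8, num_new_tokens)
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (8, num_new_tokens)

# The old policy's log probs would be computed under torch.no_grad() (or detached),
# while the current policy's log probs keep gradients for the update.
```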
In this section, we explore how the GRPO algorithm is implemented in the TRL library's `GRPOTrainer` class. The difference from the previous section is that the following code is the optimized, industrial-grade implementation of GRPO provided by HuggingFace's TRL library. Each step outlined in Section 1 is meticulously mapped to specific methods and code segments, providing a clear bridge between theory and practice. We use the code from the TRL library (version as of April 2025).
What It Does in Theory: The first step involves preparing a batch of training queries $\{q_1, q_2, \dots, q_B\}$, where $B$ is the batch size. These queries serve as the prompts that the model will respond to, forming the foundation for subsequent steps.
Implementation in TRL: In the `GRPOTrainer` class, this step is handled by the data loading mechanism inherited from the `Trainer` class in the `transformers` library, customized with a special sampler. The `_get_train_sampler` method defines a `RepeatRandomSampler` that prepares batches in a unique way:
def _get_train_sampler(self) -> Sampler:
effective_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
)
return RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=effective_batch_size // self.num_generations,
repeat_count=self.num_iterations,
seed=self.args.seed,
)
- `train_dataset` contains the prompts (stored under the key `"prompt"`).
- `RepeatRandomSampler` repeats each prompt `num_generations` times (denoted $G$ in the theory) within each batch. This ensures that for every unique prompt $q_i$, there are $G$ instances in the batch, allowing the generation of multiple outputs later.
- `effective_batch_size` accounts for the number of devices and gradient accumulation steps, ensuring scalability across distributed setups. The number of unique prompts per batch is $effective\_batch\_size / G$.
- The `repeat_count=self.num_iterations` parameter allows the same batch to be reused across multiple optimization steps, a feature unique to GRPO for efficiency.

This setup guarantees that the batch is structured to support the generation of $G$ outputs per query, aligning with Step 2. The sampler's design also ensures consistency across processes in distributed training, which is crucial for reward normalization later.
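To visualize the sampler's effect, here is a toy illustration (not TRL code) of how unique prompts are expanded so each appears `num_generations` times, grouped together within the batch:

```python
num_generations = 4  # G
unique_prompts = ["p0", "p1"]
batch = [p for p in unique_prompts for _ in range(num_generations)]
print(batch)  # ['p0', 'p0', 'p0', 'p0', 'p1', 'p1', 'p1', 'p1']
```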
What It Does in Theory: For each query $q$ in the batch, the current policy $\pi_{\theta_{old}}$ generates $G$ different outputs $\{o_1, o_2, \dots, o_G\}$, where each $o_i$ is a sequence of tokens.

Implementation in TRL: This step occurs in the `_generate_and_score_completions` method, called within `_prepare_inputs` during training:
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
buffer_index = self._step % self.args.gradient_accumulation_steps
buffered_inputs = self._buffered_inputs[buffer_index]
if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
inputs = self._generate_and_score_completions(inputs)
self._buffered_inputs[buffer_index] = inputs
else:
inputs = buffered_inputs
self._step += 1
else:
inputs = self._generate_and_score_completions(inputs)
return inputs
Inside `_generate_and_score_completions`:
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
if self.args.use_vllm:
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
max_tokens=self.max_completion_length,
# ... other sampling parameters ...
)
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
prompt_length = prompt_ids.size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
- The prompts are extracted from the batch and tokenized into `prompt_ids`.
- If `use_vllm` is enabled, the main process generates $G$ completions per unique prompt using the vLLM client. Since the batch has duplicates (from the sampler), it takes unique prompts and generates `num_generations` outputs, which are then distributed to all processes.
- Otherwise, the model's `generate` method produces one completion per prompt instance. Because the sampler repeats each prompt $G$ times, this results in $G$ outputs per unique prompt across the batch.
- The completions are stored as `completion_ids`, concatenated with `prompt_ids` for later processing.

Key Detail: The `num_generations` parameter directly corresponds to $G$, controlling how many outputs are sampled per query, fulfilling the theoretical requirement.
What It Does in Theory: Each output $o_i$ is evaluated by a reward model to obtain rewards $\{r_1, r_2, \dots, r_G\}$. The mean $\bar{r}$ and standard deviation $\sigma_r$ are computed, and the advantage for each output is calculated as $A_i = \frac{r_i - \bar{r}}{\sigma_r + \epsilon}$.
Implementation in TRL: This is also handled in `_generate_and_score_completions`:
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module):
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
else:
reward_kwargs = {key: [example[key] for example in inputs] for key in inputs[0] if key != "prompt"}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
output_reward_func = [r if r is not None else torch.nan for r in output_reward_func]
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
advantages = rewards - mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
if self.args.scale_rewards:
advantages = advantages / (std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + 1e-4)
- Reward Computation: For each completion, rewards are calculated using multiple `reward_funcs` (e.g., models or custom functions). The total reward $r_i$ is a weighted sum of individual rewards, matching the theoretical $r_i = \alpha \cdot accuracy\_score + \beta \cdot format\_score$.
- Grouping: Rewards are reshaped into groups of size `self.num_generations` to compute per-group statistics (mean and standard deviation).
- Advantage Calculation: The advantage is $r_i - \bar{r}$; if `scale_rewards` is True, it is normalized to $\frac{r_i - \bar{r}}{\sigma_r + 10^{-4}}$, directly implementing the formula from Step 3.

The group-based normalization is a hallmark of GRPO, enabling relative comparisons within each query's outputs.
What It Does in Theory: This step computes the probability ratio $ratio_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t} \mid q, o_{i,<t})}$, the clipped term $g(\epsilon, A_i)$, and the per-token loss $L_{i,t} = \min(ratio_{i,t} \cdot A_i, g(\epsilon, A_i))$. The total loss includes a KL penalty:

\[L_{total}(\theta) = \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} L_{i,t} - \beta D_{KL}[\pi_\theta || \pi_{ref}]\]

Implementation in TRL: This is implemented in the `compute_loss` method:
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
if self.beta != 0.0:
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
advantages = inputs["advantages"]
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
return loss
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
- Per-Token Log Probabilities: The `_get_per_token_logps` method computes per-token log probabilities for the current model (`per_token_logps`) and, if needed, the reference model (`ref_per_token_logps`).
- Probability Ratio: The ratio is computed as `coef_1`, the exponential of the log probability difference, equivalent to $\frac{\pi_{\theta}(o_{i,t})}{\pi_{\theta_{old}}(o_{i,t})}$.
- Clipping: `coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)`, multiplied by `advantages`, ensures the ratio stays within $[1 - \epsilon, 1 + \epsilon]$.
- Per-Token Loss: `per_token_loss1 = coef_1 * advantages` is the unclipped term, `per_token_loss2 = coef_2 * advantages` is the clipped term, and `per_token_loss = -torch.min(per_token_loss1, per_token_loss2)` computes $L_{i,t}$, negated because the training loop minimizes the loss, while GRPO aims to maximize the surrogate objective.
- KL Penalty: If `beta != 0`, the KL term is approximated as $\exp(ref\_logps - logps) - (ref\_logps - logps) - 1$ and added to the loss scaled by `beta`.
- Total Loss: `loss` averages $L_{i,t}$ over all tokens, masked by `completion_mask`, matching $\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} L_{i,t}$.

Why the Negative Sign?: In RL, we maximize the surrogate objective, but the `Trainer` minimizes the loss. Thus, $L_{i,t}$ is negated to align with this convention.
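A quick numerical check of this KL approximation with illustrative values: writing $\delta = ref\_logps - logps$, the term $\exp(\delta) - \delta - 1$ is zero when the two policies agree and grows (while staying non-negative) as they diverge.

```python
import torch

delta = torch.tensor([0.0, 0.1, -0.1, 0.5])  # hypothetical per-token log-prob differences
kl_approx = torch.exp(delta) - delta - 1
print(kl_approx)  # tensor([0.0000, 0.0052, 0.0048, 0.1487])
```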
What It Does in Theory: Compute the gradient $\nabla_\theta L_{total}(\theta)$ and update the policy parameters using an optimizer: $\theta \leftarrow \theta - \eta \nabla_\theta L_{total}(\theta)$.
Implementation in TRL: This step leverages the `Trainer` class's training loop, with no explicit override in `GRPOTrainer` for the update itself:

- The `compute_loss` method returns the loss, as shown above.
- In the `Trainer.train` method (inherited from `transformers`), the following occurs: it calls `compute_loss` to get the loss, calls `loss.backward()` to compute gradients, and the optimizer (set up via `optimizers` in `__init__`) updates the parameters using the learning rate $\eta$ (e.g., `learning_rate=5e-6` from the tutorial).

Code Context: While not explicitly shown in `GRPOTrainer`, the inherited training step can be conceptualized as:
# From transformers.Trainer.train (simplified)
for step, inputs in enumerate(epoch_iterator):
inputs = self._prepare_inputs(inputs)
loss = self.compute_loss(model, inputs)
loss = loss / self.args.gradient_accumulation_steps
loss.backward()
if (step + 1) % self.args.gradient_accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
- If `gradient_accumulation_steps > 1`, the loss is scaled and gradients are accumulated before the update, enhancing efficiency.
- The learning rate and other optimizer settings are specified in `GRPOConfig`.

This step finalizes the policy improvement, adjusting token probabilities to favor higher-reward outputs while maintaining stability via clipping and KL regularization.
Here’s a concise mapping of all steps to the `GRPOTrainer` code:

- Step 1: `_get_train_sampler` prepares batches with repeated prompts.
- Step 2: `_generate_and_score_completions` samples $G$ outputs per query.
- Step 3: `_generate_and_score_completions` computes rewards and advantages.
- Step 4: `compute_loss` calculates the surrogate loss with clipping and KL penalty.
- Step 5: The `Trainer` training loop backpropagates the loss and updates the policy.

GRPO empowers large language models with specialized skills, controlled outputs, and enhanced reasoning, surpassing traditional fine-tuning by optimizing multiple responses via a reward model. This paper delivers a concise yet thorough exploration of GRPO, from theoretical steps to practical implementation in tools like TRL. No single prior work combines its theory, practice, and nitty-gritty details—leaving gaps we now fill. By clarifying how GRPO achieves deep expertise, style precision, and debiasing, this guide equips readers to apply it confidently, advancing LLM performance for tailored, impactful use cases.
[1] DeepSeek-AI et al. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. arXiv:2501.12948.
[2] DeepSeek-AI et al. (2025). DeepSeek-V3 Technical Report. arXiv:2412.19437.
[3] Hugging Face. Understanding the DeepSeek R1 Paper, Open R1 for Students. Available at: https://huggingface.co/learn/nlp-course/chapter12/3?fw=pt (Accessed: April 2, 2025).
[4] HuggingFace and UnslothAI. Colab: HuggingFace Course - Gemma3 (1B) - GRPO. Available at: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/HuggingFace%20Course-Gemma3_(1B)-GRPO.ipynb (Accessed: April 1, 2025).
[5] HuggingFace. GRPO Trainer in TRL Library. Available at: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py (Accessed: April 1, 2025).