Prime-RL supports custom implementations for key algorithmic components, allowing you to experiment with different RL objectives and techniques.
1. Custom Loss Functions
The loss is computed per-sequence (per-sample). You provide a function that computes the loss for a single sequence, and the framework handles iteration and aggregation.
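For intuition, here is a rough sketch of that contract (illustrative only; the actual iteration and aggregation logic lives inside the Prime-RL trainer and may differ):

import torch

# Hypothetical illustration: the trainer calls your loss once per sequence
# and combines the resulting scalars into a single batch loss.
def aggregate(loss_fn, per_sequence_inputs):
    outputs = [loss_fn(inputs) for inputs in per_sequence_inputs]
    return torch.stack([out.loss for out in outputs]).mean()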
Interface
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs
def my_custom_loss(inputs: LossInputs, **kwargs) -> LossOutputs:
    ...
from dataclasses import dataclass
from jaxtyping import Bool, Float
from torch import Tensor

@dataclass
class LossInputs:
    trainer_logprobs: Float[Tensor, "seq"]    # Log probs from the current (trainer) policy
    inference_logprobs: Float[Tensor, "seq"]  # Log probs from the inference policy that generated the rollout
    teacher_logprobs: Float[Tensor, "seq"] | None  # Optional teacher log probs
    advantages: Float[Tensor, "seq"]          # Per-token advantages
    loss_mask: Bool[Tensor, "seq"]            # Mask selecting valid tokens
LossOutputs
@dataclass
class LossOutputs:
    loss: Float[Tensor, ""]     # Scalar loss for this sequence
    metrics: dict[str, Tensor]  # Metrics to log
Example: PPO Clipped Loss
import torch
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs

def ppo_clip_loss(inputs: LossInputs, clip_eps: float = 0.2) -> LossOutputs:
    # Importance ratio between the current policy and the policy that generated the rollout
    ratio = torch.exp(inputs.trainer_logprobs - inputs.inference_logprobs)
    clipped_ratio = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    # Standard PPO clipped surrogate objective (negated, since we minimize)
    surr1 = ratio * inputs.advantages
    surr2 = clipped_ratio * inputs.advantages
    loss = -torch.min(surr1, surr2)[inputs.loss_mask].sum()
    return LossOutputs(
        loss=loss,
        metrics={"clip_frac": (ratio != clipped_ratio)[inputs.loss_mask].float().mean()},
    )
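A quick smoke test with dummy tensors catches shape and masking bugs before training. This is a minimal sketch: it assumes LossInputs is constructible as a plain dataclass, and the values are arbitrary.

import torch
from prime_rl.trainer.rl.loss import LossInputs

seq_len = 8
inputs = LossInputs(
    trainer_logprobs=torch.randn(seq_len),
    inference_logprobs=torch.randn(seq_len),
    teacher_logprobs=None,
    advantages=torch.randn(seq_len),
    loss_mask=torch.ones(seq_len, dtype=torch.bool),
)
out = ppo_clip_loss(inputs, clip_eps=0.2)
assert out.loss.ndim == 0  # scalar loss for this sequence
print(out.loss.item(), out.metrics["clip_frac"].item())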
Configuration
[loss]
type = "custom"
import_path = "my_module.ppo_clip_loss"
kwargs = { clip_eps = 0.2 }
2. Custom Advantage Functions
Advantages are computed per-example (grouped by rollouts_per_example). You provide a function that computes advantages for a batch of examples.
Interface
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs
def my_custom_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
    ...
from dataclasses import dataclass
import verifiers as vf  # vf.RolloutOutput comes from the verifiers package

@dataclass
class AdvantageInputs:
    # Rollouts grouped by problem: rollouts[i][j] is the j-th rollout for problem i.
    rollouts: list[list[vf.RolloutOutput]]
Each vf.RolloutOutput carries the full rollout (reward, trajectory, etc.), so custom advantages can read any metadata they need (e.g. completion-token counts, turn counts, tool calls).
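For example, a length-penalized advantage needs only the reward plus a completion-length count from each rollout. The sketch below is illustrative: "completion_len" is a hypothetical key standing in for whatever completion-length metadata vf.RolloutOutput actually exposes.

import torch
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def length_penalized_advantage(inputs: AdvantageInputs, penalty: float = 1e-3) -> AdvantageOutputs:
    # "completion_len" is a hypothetical field name, used for illustration only
    rewards = torch.tensor(
        [[r["reward"] - penalty * r["completion_len"] for r in group] for group in inputs.rollouts]
    )
    # Reward minus the per-example mean baseline
    return AdvantageOutputs(advantages=rewards - rewards.mean(dim=1, keepdim=True))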
AdvantageOutputs
@dataclass
class AdvantageOutputs:
    advantages: Float[Tensor, "num_examples rollouts_per_example"]
Example: Normalized Advantage
import torch
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def normalized_advantage(inputs: AdvantageInputs, eps: float = 1e-8) -> AdvantageOutputs:
    """Normalize rewards within each example group to zero mean and unit variance."""
    # Shape: (num_examples, rollouts_per_example)
    rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts])
    mean = rewards.mean(dim=1, keepdim=True)
    # Note: the (unbiased) std is undefined for rollouts_per_example == 1
    std = rewards.std(dim=1, keepdim=True)
    advantages = (rewards - mean) / (std + eps)
    return AdvantageOutputs(advantages=advantages)
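As with losses, a smoke test is cheap. This sketch stands in plain dicts for vf.RolloutOutput, relying only on the "reward" access used above:

from prime_rl.orchestrator.advantage import AdvantageInputs

rollouts = [
    [{"reward": 1.0}, {"reward": 0.0}],  # problem 0: two rollouts
    [{"reward": 0.5}, {"reward": 0.5}],  # problem 1: two rollouts
]
out = normalized_advantage(AdvantageInputs(rollouts=rollouts))
assert out.advantages.shape == (2, 2)  # (num_examples, rollouts_per_example)
print(out.advantages)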
Configuration
[advantage]
type = "custom"
import_path = "my_module.normalized_advantage"
kwargs = { eps = 1e-8 }
Default Implementations
If no custom function is specified:
- Loss: uses default_loss_fn (masked importance sampling with KL against the inference policy, and optional masking strategies)
- Advantage: uses default_advantage_fn (reward minus the per-example baseline, a.k.a. DR-GRPO without std normalization); both are sketched below
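For intuition, here is a conceptual sketch of what these defaults compute. This is not the actual Prime-RL source; the real implementations add configurable KL weighting and masking.

import torch
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def sketch_default_loss(inputs: LossInputs) -> LossOutputs:
    # Masked importance-sampling surrogate; the real default also adds the
    # KL term against the inference policy and extra masking strategies.
    ratio = torch.exp(inputs.trainer_logprobs - inputs.inference_logprobs)
    loss = -(ratio * inputs.advantages)[inputs.loss_mask].sum()
    return LossOutputs(loss=loss, metrics={})

def sketch_default_advantage(inputs: AdvantageInputs) -> AdvantageOutputs:
    rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts])
    # Reward minus the per-example mean baseline (no std normalization)
    return AdvantageOutputs(advantages=rewards - rewards.mean(dim=1, keepdim=True))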
See LossConfig and AdvantageConfig for available parameters.
Tips
- Your functions receive structured inputs via dataclasses with jaxtyping annotations
- Return metrics as scalars or 1D tensors; they'll be aggregated automatically
- Use the loss_mask and the annotated tensor shapes to handle variable-length sequences
- Test your custom functions with the provided test patterns before training