PRIME-RL implements asynchronous off-policy training instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy up to $k$ steps ahead of the trainer (in the code we call this `async_level`). With $k=1$ and equal trainer and inference step timings, this allows running without any idle time on either the trainer or inference. By default, we set $k=2$ to allow overlap with a weight broadcast over the Internet, which is needed for decentralized training.

*Figure: Two-Step Off-Policy Training*
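
The scheduling idea can be summarized with a small sketch. The names below (`trainer`, `inference`, `rollout_queue`, and their methods) are illustrative and not PRIME-RL's actual API; the point is that the trainer at step $n$ consumes rollouts generated with weights that are at most `async_level` steps stale.

```python
# Minimal sketch of async off-policy scheduling (hypothetical interfaces, not
# the actual PRIME-RL API): inference runs up to k steps ahead of the trainer.
from collections import deque

ASYNC_LEVEL = 2  # k: maximum allowed staleness of the rollout policy


def train(trainer, inference, dataset, num_steps: int, k: int = ASYNC_LEVEL):
    rollout_queue = deque()

    # Pre-fill the queue so inference can run k steps ahead of the trainer.
    for _ in range(k):
        rollout_queue.append(inference.generate(dataset.sample()))

    for step in range(1, num_steps + 1):
        # Train on the oldest pending rollouts; they were generated with
        # weights from step max(0, step - k).
        rollouts = rollout_queue.popleft()
        trainer.step(rollouts)

        # Broadcast the new weights to inference (in practice this overlaps
        # with ongoing generation, which is why k=2 is the default).
        inference.update_weights(trainer.state_dict(), step=step)

        # Keep inference busy: enqueue the next batch of rollouts.
        rollout_queue.append(inference.generate(dataset.sample()))
```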

Loss Objective

We adopt a loss objective capable of handling the natural distribution shift caused by the off-policy nature of the training. By default, we use a token-level loss variant of the AIPO training objective introduced in Llama-RL, but omit the entropy and KL loss terms. At each step, we sample $N$ prompts from our dataset. For each prompt $x$, we sample a group of rollouts $\{y_i\}_{i=1}^G$ and use a verifier to assign scores $s_i$ to each $y_i$. Then, the optimization objective is given by

$$
\mathcal{J}_{\text{AIPO}}(\theta) = \frac{1}{\sum_{j=1}^N \sum_{i=1}^G |y_i^{(j)}|} \sum_{j=1}^N \sum_{i=1}^G \sum_{t=1}^{|y_i^{(j)}|} \min\left( \frac{\pi(y^{(j)}_{i,t}\mid x_j, y^{(j)}_{i,<t})}{\mu(y^{(j)}_{i,t}\mid x_j, y^{(j)}_{i,<t})}, \delta \right)\hat{A}^{(j)}_{i,t}
$$

where $\mu$ refers to the policy that generated the rollout, $\pi$ refers to the current policy, $\hat{A}_{i,t}$ is the token-level advantage, and $\delta$ is the importance sampling clipping ratio.
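
To make the objective concrete, here is a minimal PyTorch sketch of the token-level loss above. It is not PRIME-RL's actual implementation; the flattened tensor layout, the `mask` argument, and the default value of $\delta$ are assumptions for illustration.

```python
# Hedged sketch of the token-level AIPO objective (assumed tensor layout:
# all completion tokens flattened across prompts and rollouts).
import torch


def aipo_loss(
    logprobs: torch.Tensor,      # log pi(y_t | x, y_<t) under the current policy
    old_logprobs: torch.Tensor,  # log mu(y_t | x, y_<t) under the rollout policy
    advantages: torch.Tensor,    # token-level advantages A_hat
    mask: torch.Tensor,          # 1 for completion tokens, 0 for padding
    delta: float = 4.0,          # importance sampling clipping ratio (assumed value)
) -> torch.Tensor:
    # Importance ratio pi / mu, clipped from above at delta.
    ratio = torch.exp(logprobs - old_logprobs)
    clipped_ratio = torch.clamp(ratio, max=delta)

    # Token-level objective, normalized by the total number of completion tokens.
    objective = (clipped_ratio * advantages * mask).sum() / mask.sum()

    # Return the negative objective so a standard optimizer can minimize it.
    return -objective
```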

Step Semantics

PRIME-RL uses a global training step $n = 1, 2, 3, \dots$ that is used to tag artifacts:
  • Trainer: Produces policy $\pi_n$ with weights $\theta_n$ from rollouts $(x_n, y_n)$
  • Inference: Produces rollouts $(x_n, y_n)$ from policy $\pi_{\max(0,\, n-k)}$
Here, $k$ is the `async_level` parameter, which defaults to 2. Note that we use 0-indexed steps to cleanly indicate that the off-policy gap at each step is at most $k$ steps.
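
The mapping between a trainer step and the policy that generated its rollouts can be illustrated with a few lines of Python; the helper below is hypothetical and only mirrors the $\pi_{\max(0,\, n-k)}$ rule above.

```python
# Illustration of the step semantics for async_level k = 2: the rollouts
# consumed at trainer step n were generated with policy pi_{max(0, n - k)}.
def rollout_policy_step(n: int, k: int = 2) -> int:
    return max(0, n - k)


for n in range(1, 6):
    print(f"step {n}: rollouts (x_{n}, y_{n}) generated by pi_{rollout_policy_step(n)}")
# step 1: rollouts (x_1, y_1) generated by pi_0
# step 2: rollouts (x_2, y_2) generated by pi_0
# step 3: rollouts (x_3, y_3) generated by pi_1
# step 4: rollouts (x_4, y_4) generated by pi_2
# step 5: rollouts (x_5, y_5) generated by pi_3
```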