PRIME-RL implements asynchronous off-policy training instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy, running up to $k$ steps ahead of the trainer (in the code, we call this `async_level`). With $k=1$ and equal trainer and inference step timings, this allows both the trainer and inference to run without any idle time. By default, we set $k=2$ to allow overlap with a weight broadcast over the Internet, which is needed for decentralized training.
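To make this concrete, below is a minimal, self-contained sketch of bounded-staleness scheduling. The names (`generate_rollouts`, `train_step`, `latest_policy_step`) are stand-ins for illustration, not PRIME-RL's actual API:

```python
# Sketch only (assumed names, not PRIME-RL's API): inference may run up to
# `async_level` steps ahead of the trainer, so neither side has to sit idle.
import asyncio

ASYNC_LEVEL = 2   # k: maximum off-policy gap, in trainer steps
NUM_STEPS = 8

latest_policy_step = 0  # step index of the most recently published policy (0 = initial weights)


async def generate_rollouts(policy_step: int) -> str:
    await asyncio.sleep(0.01)  # stand-in for rollout generation on the inference side
    return f"rollouts from policy {policy_step}"


async def train_step(rollouts: str) -> None:
    await asyncio.sleep(0.01)  # stand-in for one optimizer step


async def inference_loop(queue: asyncio.Queue) -> None:
    for n in range(1, NUM_STEPS + 1):
        # Only generate rollouts for step n once the policy is at most k steps stale.
        while latest_policy_step < n - ASYNC_LEVEL:
            await asyncio.sleep(0.001)
        await queue.put((n, await generate_rollouts(latest_policy_step)))


async def trainer_loop(queue: asyncio.Queue) -> None:
    global latest_policy_step
    for n in range(1, NUM_STEPS + 1):
        _, rollouts = await queue.get()
        await train_step(rollouts)
        latest_policy_step = n  # stand-in for broadcasting the new weights to inference
        print(f"step {n}: trained on {rollouts}")


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await asyncio.gather(inference_loop(queue), trainer_loop(queue))


asyncio.run(main())
```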
Loss Objective
We adopt a loss objective capable of handling the natural distribution shift caused by the off-policy nature of the training. By default, we use a token-level loss variant of the AIPO training objective introduced in Llama-RL,
but omit the entropy and KL loss terms.
At each step, we sample $N$ prompts from our dataset. For each prompt $x$, we sample a group of rollouts $\{y_i\}_{i=1}^{G}$ and use a verifier to assign scores $s_i$ to each $y_i$. Then, the optimization objective is given by

$$J_{\mathrm{AIPO}}(\theta) = \frac{1}{\sum_{j=1}^{N}\sum_{i=1}^{G} |y_i^{(j)}|} \sum_{j=1}^{N}\sum_{i=1}^{G}\sum_{t=1}^{|y_i^{(j)}|} \min\!\left(\frac{\pi(y_{i,t}^{(j)} \mid x_j, y_{i,<t}^{(j)})}{\mu(y_{i,t}^{(j)} \mid x_j, y_{i,<t}^{(j)})},\ \delta\right) \hat{A}_{i,t}^{(j)}$$
where $\mu$ refers to the policy that generated the rollout, $\pi$ refers to the current policy, $\hat{A}_{i,t}^{(j)}$ is the token-level advantage, and $\delta$ is the importance sampling clipping ratio.
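As a concrete reference, here is a sketch of this token-level objective in PyTorch. The function name, tensor shapes, and argument names are assumptions for illustration, not the repository's exact code:

```python
# Sketch of the token-level AIPO objective (assumed shapes; not the exact PRIME-RL code).
import torch


def aipo_loss(
    logprobs: torch.Tensor,           # [B, T] log pi(y_t | x, y_<t) under the current policy
    behavior_logprobs: torch.Tensor,  # [B, T] log mu(y_t | x, y_<t) recorded at rollout time
    advantages: torch.Tensor,         # [B, T] token-level advantages A_hat
    mask: torch.Tensor,               # [B, T] 1 for completion tokens, 0 for prompt/padding
    delta: float,                     # importance sampling clipping ratio
) -> torch.Tensor:
    # Importance ratio pi / mu, clipped from above at delta.
    ratio = torch.exp(logprobs - behavior_logprobs)
    clipped_ratio = torch.clamp(ratio, max=delta)

    # Token-level objective, averaged over all completion tokens in the batch.
    per_token = clipped_ratio * advantages * mask
    objective = per_token.sum() / mask.sum().clamp(min=1)

    # Negate so the objective can be maximized with a standard minimizing optimizer.
    return -objective
```

In practice, the behavior log-probabilities would be recorded at rollout time, so the ratio can be formed without re-running the stale policy during training.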
Step Semantics
PRIME-RL uses a global training step $n = 1, 2, 3, \ldots$ to tag artifacts:
- Trainer: Produces policy $\pi_n$ with weights $\theta_n$ from rollouts $(x_n, y_n)$
- Inference: Produces rollouts $(x_n, y_n)$ from policy $\pi_{\max(0,\, n-k)}$
Here, $k$ is the `async_level` parameter, which defaults to 2. Note that indexing starts at 0, with $\pi_0$ denoting the initial policy, which makes it immediately clear that at each step the off-policy gap is at most $k$ steps.
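For illustration, assuming $\pi_0$ is the initial policy, the first few steps with the default $k=2$ map out as follows (a toy snippet, not project code):

```python
# Toy illustration of the step semantics for async_level k = 2.
async_level = 2

for n in range(1, 6):
    behavior_step = max(0, n - async_level)
    print(f"step {n}: trainer produces pi_{n} from rollouts generated by pi_{behavior_step}")

# Output:
# step 1: trainer produces pi_1 from rollouts generated by pi_0
# step 2: trainer produces pi_2 from rollouts generated by pi_0
# step 3: trainer produces pi_3 from rollouts generated by pi_1
# step 4: trainer produces pi_4 from rollouts generated by pi_2
# step 5: trainer produces pi_5 from rollouts generated by pi_3
```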