Scaling prime-rl from a single GPU to a 1000-GPU cluster: single-node and multi-node deployments, FSDP / expert parallelism / context parallelism, and throughput benchmarking. See Training for detailed documentation of the trainer configuration and Inference for the inference configuration.
Single-Node vs. Multi-Node Deployment
The rl, sft, and inference entrypoints all accept a [deployment] block (type = "single_node" or "multi_node") that controls how the trainer, orchestrator, and inference processes are placed on the hardware. Single-node runs locally. Multi-node currently goes through SLURM: the launcher writes an sbatch script that places inference replicas, the orchestrator, and the trainer, with the right rendezvous endpoints, IPs, ports, and shared-filesystem paths wired in.
Single-Node
RL Placement
rl defaults to 1 trainer GPU and 1 inference GPU. To give inference 6 GPUs with data parallelism and the trainer the remaining 2 on an 8-GPU node:
GPUs are assigned in order from CUDA_VISIBLE_DEVICES (or from all visible GPUs if it is unset): inference first, then the trainer, then the teacher. To target a specific physical subset, pin CUDA_VISIBLE_DEVICES before launching.
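To make the ordering concrete, here is a minimal sketch of how an ordered GPU list splits into contiguous inference, trainer, and teacher groups. The helpers assign_gpus and visible_gpus are hypothetical and are not part of prime-rl. They only mirror the order described above, and they use physical GPU ids for readability.

```python
def visible_gpus(env_value, total_gpus):
    """Parse a CUDA_VISIBLE_DEVICES value; fall back to all GPUs when it is unset."""
    if env_value:
        return [g.strip() for g in env_value.split(",")]
    return [str(i) for i in range(total_gpus)]


def assign_gpus(visible, num_inference, num_trainer, num_teacher=0):
    """Hypothetical helper: inference first, trainer next, teacher last."""
    needed = num_inference + num_trainer + num_teacher
    if needed > len(visible):
        raise ValueError(f"need {needed} GPUs but only {len(visible)} are visible")
    end_inf = num_inference
    end_trn = end_inf + num_trainer
    return {
        "inference": visible[:end_inf],
        "trainer": visible[end_inf:end_trn],
        "teacher": visible[end_trn:needed],
    }


# 8-GPU node: 6 inference GPUs, 2 trainer GPUs.
print(assign_gpus(visible_gpus(None, 8), num_inference=6, num_trainer=2))
# {'inference': ['0', '1', '2', '3', '4', '5'], 'trainer': ['6', '7'], 'teacher': []}

# Pinned to the upper half of the node, with a teacher.
print(assign_gpus(visible_gpus("4,5,6,7", 8), 2, 1, 1))
# {'inference': ['4', '5'], 'trainer': ['6'], 'teacher': ['7']}
```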
For quick A/B ablations on the same node, run two RL instances side by side in separate tmux sessions, each pinned to half the GPUs and using a separate inference port.
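The split itself is simple arithmetic. The sketch below computes the CUDA_VISIBLE_DEVICES value and inference port for each of the two runs. The helper split_for_ab, the base port 8000, and the inference_port key are illustrative assumptions. How the port is actually passed to each rl instance depends on its inference config, which is not shown here.

```python
def split_for_ab(num_gpus=8, base_port=8000):
    """Hypothetical helper: two runs, each on half the GPUs with its own port."""
    if num_gpus % 2:
        raise ValueError("need an even number of GPUs to split in half")
    half = num_gpus // 2
    runs = []
    for i in range(2):
        gpu_ids = range(i * half, (i + 1) * half)
        runs.append({
            "CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in gpu_ids),
            "inference_port": base_port + i,  # assumed port scheme
        })
    return runs


for run in split_for_ab():
    print(run)
# {'CUDA_VISIBLE_DEVICES': '0,1,2,3', 'inference_port': 8000}
# {'CUDA_VISIBLE_DEVICES': '4,5,6,7', 'inference_port': 8001}
```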
SFT and Torchrun
uv run sft handles distributed launch internally. To scale from 1 to N GPUs, set the deployment GPU count (or just let it pick up WORLD_SIZE). For non-default layouts, the manual equivalent is:
--local-ranks-filter 0 keeps console output to rank 0 only; per-rank stdout/stderr is still captured in <output_dir>/logs/trainer/torchrun/.
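Under torchrun, each worker process receives its position through the RANK, LOCAL_RANK, and WORLD_SIZE environment variables. The sketch below reads them and shows the effect of filtering console output to local rank 0. The function names are hypothetical, and the defaults (rank 0, world size 1) describe a plain single-process run.

```python
import os


def dist_env(env=None):
    """Read the rank info torchrun exports to every worker (hypothetical helper)."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", "0")),
        "local_rank": int(env.get("LOCAL_RANK", "0")),
        "world_size": int(env.get("WORLD_SIZE", "1")),
    }


def prints_to_console(info):
    """Mimic `--local-ranks-filter 0`: only local rank 0 echoes to the console."""
    return info["local_rank"] == 0


worker = dist_env({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "4"})
print(worker, prints_to_console(worker))
# {'rank': 3, 'local_rank': 3, 'world_size': 4} False
```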
Multi-Node
Multi-node deployments (RL or SFT) are launched via SLURM. Set [deployment] type = "multi_node" plus the matching [slurm] block, and the launcher writes an sbatch script that places inference, orchestrator, and trainer across the requested nodes with the inter-process wiring set up correctly. See SLURM § Examples for full configs.
Parallelism Knobs
FSDP
FSDP2 is the default model sharding strategy. By default, the trainer fully shards parameters, gradients, and optimizer state across the data-parallel mesh. Tunable knobs:

| Knob | Effect |
|---|---|
| trainer.model.dp_replicate | Data-parallel replication degree: parameters are sharded within each replica and replicated across replicas. Set to 2 to run 2-way DP replication × FSDP sharding within each replica, which helps on very large clusters where pure FSDP communication dominates. |
| trainer.model.reshard_after_forward | If true (default), parameters are resharded after the forward pass to free memory, and the backward pass re-gathers them. Set to false to keep parameters resident: faster, but uses more memory. |
| trainer.model.fsdp_cpu_offload | Offloads parameters, gradients, and optimizer state to CPU. Large memory savings, large throughput cost. |
| trainer.model.optim_cpu_offload | Offloads only the optimizer state. A middle ground: small throughput cost and decent memory savings, especially at low GPU counts. |
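
To see what dp_replicate does to the process groups, the sketch below lays ranks out on a 2-D (replicate, shard) grid. It assumes a row-major layout, which is what PyTorch's init_device_mesh produces for that shape; prime-rl's exact mesh construction may differ, and hsdp_groups is a hypothetical helper. Each rank holds 1 / shard_size of the sharded state, FSDP collectives stay inside a shard group, and gradients are all-reduced across each replicate group.

```python
def hsdp_groups(world_size, dp_replicate):
    """Return (shard_groups, replicate_groups) for a (replicate, shard) mesh."""
    if world_size % dp_replicate:
        raise ValueError("world_size must be divisible by dp_replicate")
    shard = world_size // dp_replicate
    # FSDP all-gathers and reduce-scatters run inside each shard group (a mesh row).
    shard_groups = [list(range(r * shard, (r + 1) * shard)) for r in range(dp_replicate)]
    # Ranks holding the same shard form a replicate group (a mesh column).
    replicate_groups = [list(range(s, world_size, shard)) for s in range(shard)]
    return shard_groups, replicate_groups


shards, replicas = hsdp_groups(world_size=8, dp_replicate=2)
print(shards)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicas)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```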
Expert Parallelism
EP shards MoE expert weights across the EP mesh, dramatically reducing the FSDP communication volume per layer and improving the training throughput. EP is only available with the custom model implementation (model.impl = "custom" or "auto" for supported families).
ep_comm_backend = "deepep" uses DeepEP’s custom dispatch/combine kernels for speed, with two extra knobs (deepep_num_sms, deepep_token_chunk_size) — tune on your hardware.
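The sketch below illustrates the idea behind EP. Each rank owns only num_experts / ep_size experts, and tokens are sent to the rank that owns their routed expert (the dispatch step, which DeepEP accelerates), so every expert's weights never need to be gathered everywhere. The contiguous expert assignment and top-1 routing are simplifying assumptions, and both helpers are hypothetical.

```python
def experts_on_rank(num_experts, ep_size):
    """Map each EP rank to the contiguous block of experts it owns (assumed layout)."""
    if num_experts % ep_size:
        raise ValueError("num_experts must be divisible by the EP degree")
    per_rank = num_experts // ep_size
    return {r: list(range(r * per_rank, (r + 1) * per_rank)) for r in range(ep_size)}


def dispatch(token_experts, num_experts, ep_size):
    """Group token indices by the EP rank owning their routed expert (top-1 routing)."""
    per_rank = num_experts // ep_size
    buckets = {r: [] for r in range(ep_size)}
    for token, expert in enumerate(token_experts):
        buckets[expert // per_rank].append(token)
    return buckets


print(experts_on_rank(8, 4))            # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
print(dispatch([0, 5, 3, 7, 1], 8, 4))  # {0: [0, 4], 1: [2], 2: [1], 3: [3]}
```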
Context Parallelism
CP shards a single sequence across multiple GPUs along the token dimension, for long-context training. We recommend Ulysses-style CP for most models, since it gives the highest throughput. Some models (e.g. GLM-5) support only ring-style CP; an unsupported setting is rejected at config validation.
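The two CP styles differ in what they exchange. The sketch below shows the shared first step (contiguous token sharding), the head-divisibility constraint that Ulysses imposes because its all-to-all swaps sequence shards for head shards around attention, and the number of KV block transfers each rank receives in ring CP. These are illustrative helpers, not prime-rl code. Real ring implementations often use a load-balanced (e.g. zigzag) split for causal attention instead of the contiguous split shown here.

```python
def shard_sequence(tokens, cp_size):
    """Split one sequence into cp_size contiguous chunks along the token dimension."""
    if len(tokens) % cp_size:
        raise ValueError("sequence length must be divisible by the CP degree")
    chunk = len(tokens) // cp_size
    return [tokens[r * chunk:(r + 1) * chunk] for r in range(cp_size)]


def ulysses_heads_per_rank(num_heads, cp_size):
    """Ulysses: each rank attends over the full sequence for num_heads / cp_size heads."""
    if num_heads % cp_size:
        raise ValueError("Ulysses CP needs num_heads divisible by the CP degree")
    return num_heads // cp_size


def ring_kv_transfers(cp_size):
    """Ring: heads stay whole; each rank receives every other rank's KV block once."""
    return cp_size - 1


print(shard_sequence(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(ulysses_heads_per_rank(32, 4))      # 8
print(ring_kv_transfers(4))               # 3
```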
Activation Checkpointing and Offloading
| Knob | Memory saved | Throughput cost |
|---|---|---|
| trainer.model.ac | large | ~25% |
| trainer.model.ac.mode = "selective" | medium | small |
| trainer.model.ac_offloading | extra | a bit more |

Enabling ac_offloading with ac_offloading.max_inflight_activations = 5 further reduces the memory footprint at the cost of some throughput. We've observed this feature to be very effective, lowering peak memory usage by 30-40% in some cases while losing only ~3-5% of throughput.
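The memory/compute trade behind activation checkpointing can be shown on a toy chain of layers. The sketch below keeps only every k-th layer input during the forward pass and recomputes any other layer's input from the nearest earlier checkpoint. This is the generic technique, not prime-rl's implementation, which works on tensors per transformer block. Offloading differs in that it copies saved activations to CPU instead of dropping them.

```python
def forward(x, layers, checkpoint_every=1):
    """Run layers in order, keeping the input of every `checkpoint_every`-th layer.

    checkpoint_every=1 stores every layer input (no checkpointing).
    """
    saved = {}
    for i, layer in enumerate(layers):
        if i % checkpoint_every == 0:
            saved[i] = x
        x = layer(x)
    return x, saved


def layer_input(i, layers, saved):
    """Recompute the input of layer i from the nearest earlier checkpoint."""
    start = max(k for k in saved if k <= i)
    x = saved[start]
    for layer in layers[start:i]:
        x = layer(x)
    return x


layers = [lambda x, k=k: 2 * x + k for k in range(8)]
_, full = forward(1, layers)
_, ckpt = forward(1, layers, checkpoint_every=4)
print(len(full), len(ckpt))                     # 8 2
print(layer_input(5, layers, ckpt) == full[5])  # True
```

Storing 2 instead of 8 activations costs one extra layer evaluation per recomputed input, which is the throughput hit the table above refers to.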
Optimizer Offloading
Offloading optimizer state to CPU (optim_cpu_offload) is a near-free memory win at low GPU counts. It is incompatible with fsdp_cpu_offload, and also with trainer.max_concurrent_runs > 1 (multi-tenant training). Muon does not support fsdp_cpu_offload but does support optim_cpu_offload.
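A quick back-of-the-envelope calculation shows why the savings matter most at low GPU counts. The sketch assumes Adam with two fp32 moments (8 bytes per parameter) fully sharded by FSDP, and it ignores master weights and other buffers. The function is a hypothetical helper, and the 30B figure is purely illustrative.

```python
def adam_state_gib_per_gpu(num_params, fsdp_degree, bytes_per_param=8):
    """GiB of Adam moments each GPU holds when the state is fully sharded.

    Assumes two fp32 moments (exp_avg, exp_avg_sq) = 8 bytes per parameter.
    """
    return num_params * bytes_per_param / fsdp_degree / 2**30


for gpus in (8, 64, 512):
    print(gpus, round(adam_state_gib_per_gpu(30e9, gpus), 2))
# 8 27.94
# 64 3.49
# 512 0.44
```

On 8 GPUs, offloading frees roughly 28 GiB per GPU for this model size; on 512 GPUs, the same knob frees under half a GiB.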
LM Head Chunking
The vanilla LM head materializes a [batch * seq, vocab] logits tensor on every step, a major memory tax when the vocabulary is large (often >100K). fused_lm_head_token_chunk_size swaps in a custom fused linear + logprob/entropy kernel that streams through chunk_size tokens at a time, avoiding the materialization.
auto is a safe starting point for RL. Drop the chunk size further when peak memory is still tight (e.g. with very long sequences); raise it to amortize kernel-launch overhead. Only available with model.impl = "custom", and currently RL-only — the SFT trainer rejects integer values.
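The chunking idea can be shown in plain Python. The sketch below computes per-token log-probabilities of the targets and per-token entropies while only ever holding a [chunk_size, vocab] block of logits. It is a CPU illustration of the streaming pattern under simplifying assumptions (naive matmul, no fusion). It is not the prime-rl kernel, and chunked_logprob_entropy is a hypothetical name. The entropy uses H = logsumexp(z) - Σ p·z, which equals -Σ p·log p.

```python
import math


def chunked_logprob_entropy(hidden, weight, targets, chunk_size):
    """hidden: N x D rows, weight: V x D LM-head rows, targets: N token ids."""
    logprobs, entropies = [], []
    for start in range(0, len(hidden), chunk_size):
        h_chunk = hidden[start:start + chunk_size]
        t_chunk = targets[start:start + chunk_size]
        # Only a [chunk_size, vocab] block of logits exists at any time.
        logits = [[sum(a * b for a, b in zip(h, w)) for w in weight] for h in h_chunk]
        for row, t in zip(logits, t_chunk):
            m = max(row)  # subtract the max for a numerically stable logsumexp
            lse = m + math.log(sum(math.exp(z - m) for z in row))
            probs = [math.exp(z - lse) for z in row]
            logprobs.append(row[t] - lse)
            entropies.append(lse - sum(p * z for p, z in zip(probs, row)))
    return logprobs, entropies


hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, 1.0]]  # vocabulary of 3
targets = [0, 1, 2, 0]
small = chunked_logprob_entropy(hidden, weight, targets, chunk_size=1)
large = chunked_logprob_entropy(hidden, weight, targets, chunk_size=4)
print(all(math.isclose(a, b) for a, b in zip(small[0] + small[1], large[0] + large[1])))  # True
```

The chunk size changes only peak memory and per-chunk overhead, not the results, which is why it can be tuned freely.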
Memory-Tight Recipe
The kitchen-sink config for fitting large MoE models on limited GPUs at acceptable throughput. In it, torch.compile reduces memory fragmentation, and optimizer offloading moves Adam state off the GPU. Apply these knobs selectively, since each one has a throughput cost.
SLURM
The rl, sft, and inference entrypoints all submit to SLURM when a [slurm] table is present; there is no separate entrypoint.
Activation
A SLURM config is usually a thin overlay that adds [slurm] (and [deployment] for multi-node) on top of a base config. Configs are composed left to right via the @ CLI syntax (see Configuration § TOML Composition).
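One way to picture left-to-right composition is a recursive merge of the parsed TOML tables, where later files override earlier ones key by key and nested tables merge. The sketch below assumes exactly those semantics, which may not match every detail of prime-rl's composition rules (see Configuration § TOML Composition for the authoritative behavior). compose is a hypothetical helper, and the keys inside the example tables (such as partition) are made up.

```python
import copy


def compose(*configs):
    """Merge parsed TOML tables left to right; later files win on conflicts."""
    result = {}
    for cfg in configs:
        _merge_into(result, cfg)
    return result


def _merge_into(dst, src):
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            _merge_into(dst[key], value)  # merge nested tables key by key
        else:
            dst[key] = copy.deepcopy(value)  # scalars, lists, new tables: overwrite


base = {"model": {"name": "my-model"}, "deployment": {"type": "single_node"}}
overlay = {"deployment": {"type": "multi_node"}, "slurm": {"partition": "gpu"}}
print(compose(base, overlay))
# {'model': {'name': 'my-model'}, 'deployment': {'type': 'multi_node'}, 'slurm': {'partition': 'gpu'}}
```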
[deployment] Block
[deployment] is a discriminated union selected by type: single_node or multi_node for RL/SFT, with an extra disaggregated variant for inference.
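A discriminated union means the type field alone decides which schema applies, and each entrypoint accepts only its own variants. The sketch below models that with plain dataclasses. It is not prime-rl's actual schema. It deliberately omits every field except type because the real variant fields are not listed here, and ALLOWED and parse_deployment are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class SingleNode:
    type: str = "single_node"


@dataclass
class MultiNode:
    type: str = "multi_node"


@dataclass
class Disaggregated:
    type: str = "disaggregated"


# Variants each entrypoint accepts, per the description above.
ALLOWED = {
    "rl": {"single_node": SingleNode, "multi_node": MultiNode},
    "sft": {"single_node": SingleNode, "multi_node": MultiNode},
    "inference": {"single_node": SingleNode, "multi_node": MultiNode,
                  "disaggregated": Disaggregated},
}


def parse_deployment(entrypoint, table):
    """Pick the variant class from the `type` discriminator, rejecting unknown ones."""
    variants = ALLOWED[entrypoint]
    kind = table.get("type")
    if kind not in variants:
        raise ValueError(f"{entrypoint}: unsupported deployment type {kind!r}")
    return variants[kind](**table)


print(parse_deployment("rl", {"type": "multi_node"}))
# MultiNode(type='multi_node')
```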
Examples
Full multi-node configs ship in examples/multinode/:
- rl.toml: two-node RL run with NCCL weight broadcast on a 30B MoE student.
- sft.toml: two-node SFT against the same model.
For multi-node inference, set [deployment] type = "multi_node" in an inference TOML. Each node runs an independent vLLM replica (TP and DP must fit within one node), and the launcher prints one URL per node. Put a router in front of these URLs, or point clients at any one of them.
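If you skip the router, a client can spread load across replicas itself. The sketch below cycles through replica URLs round-robin. The class name and the example URLs (host names, port 8000, /v1 path) are hypothetical, so substitute the URLs the launcher prints. A production router would also handle health checks and load-aware balancing, which this sketch does not.

```python
import itertools


class RoundRobinClient:
    """Hand out replica base URLs in turn; a stand-in for a real router."""

    def __init__(self, urls):
        if not urls:
            raise ValueError("need at least one replica URL")
        self._urls = itertools.cycle(list(urls))

    def next_url(self):
        return next(self._urls)


# Hypothetical URLs; use the ones the launcher prints for your job.
client = RoundRobinClient(["http://node-0:8000/v1", "http://node-1:8000/v1"])
print([client.next_url() for _ in range(3)])
# ['http://node-0:8000/v1', 'http://node-1:8000/v1', 'http://node-0:8000/v1']
```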
Custom Templates
For unusual partitions, module loads, or environment setup, supply your own Jinja2 template. The built-in templates live in src/prime_rl/templates/; copy one as a starting point.
Benchmarking
Every entrypoint supports a --bench flag that runs a few warm-up and measurement steps with fake data and prints a rich-formatted throughput / MFU table. To also save the results as JSON, use --bench.output-json. Use this to compare parallelism configs before committing to a multi-day run.
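For reading the MFU column, the sketch below uses the common approximation of about 6 FLOPs per parameter per token for a forward plus backward pass, ignoring attention FLOPs. For MoE models, the usual practice is to count only the parameters active per token. prime-rl's benchmark may count FLOPs differently, and the peak FLOP/s value is an input you supply for your hardware. The function name and the numbers below are illustrative.

```python
def mfu(num_params, tokens_per_second, num_gpus, peak_flops_per_gpu):
    """Model FLOPs utilization using the ~6 * params FLOPs/token estimate."""
    achieved_flops = 6 * num_params * tokens_per_second
    return achieved_flops / (num_gpus * peak_flops_per_gpu)


# Illustrative only: a 1B-parameter dense model at 10k tokens/s on one GPU
# with an assumed 100 TFLOP/s peak.
print(round(mfu(1e9, 1e4, 1, 1e14), 3))  # 0.6
```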