In this tutorial, we will see how to train a model across multiple nodes using PyTorch FSDP.

To follow this tutorial, you need access to multiple nodes. On each node, you need to install torch, transformers, and datasets.

You can access all of the Python code used in this tutorial here.

1 - Single GPU training

Let’s start with a simple script to train a 150M-parameter language model on the C4 dataset on one GPU.

train_simple.py
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    LlamaConfig,
    LlamaForCausalLM,
)
from datasets import load_dataset

GLOBAL_BATCH_SIZE = 128
MICRO_BATCH_SIZE = 8
GRAD_ACCUMULATION_STEPS = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE

def main():  
    ## load model
    model_config = LlamaConfig.from_pretrained("PrimeIntellect/llama-150m-fresh")
    model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="PrimeIntellect/llama-150m-fresh", config=model_config)
    model = model.to("cuda")

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    ## prepare data
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
    tokenizer.pad_token = "</s>"  # Ensure pad token is set for models that need it

    ds = load_dataset("allenai/c4", "en", streaming=True)

    def tokenize_function(data):
        outputs = tokenizer(
            data["text"],
            truncation=True,
            max_length=1024,
            padding="max_length",
        )
        return outputs

    tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])["train"]

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    train_loader = DataLoader(
        tokenized_datasets,
        collate_fn=data_collator,
        batch_size=MICRO_BATCH_SIZE,
        num_workers=4,
    )

    for step, batch in enumerate(train_loader):

        # True while we are still accumulating gradients for the current global batch
        is_accumulating = (step + 1) % GRAD_ACCUMULATION_STEPS != 0

        for key in batch.keys():
            batch[key] = batch[key].to("cuda")

        outputs = model(**batch)
        loss = outputs.loss / GRAD_ACCUMULATION_STEPS
        loss.backward()

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
        
        print(f"step: {step}")


if __name__ == "__main__":
    main()

You can run this on one GPU by doing:

python train_simple.py
>>> step: 0
>>> step: 1
>>> step: 2
...

2 - Torch Distributed

Let’s now turn this script into a multi-GPU (and later multi-node) script using FSDP and torchrun.

Distributed training in PyTorch follows the SPMD (Single Program, Multiple Data) paradigm: the same training code runs on multiple GPUs, each in its own process, and the processes communicate with each other using torch.distributed. torchrun is a PyTorch utility that spawns a process per GPU and makes sure the processes can communicate with each other.

To use torch.distributed, we need to initialize the process group before training and destroy it afterward. We also need to make sure each process sets its default CUDA device to LOCAL_RANK, i.e., the GPU index within the current node.

train_multi_gpu.py
import os
import torch
from torch.utils.data import DataLoader
from torch.distributed import destroy_process_group, init_process_group

...

GLOBAL_BATCH_SIZE = 128
MICRO_BATCH_SIZE = 8
GRAD_ACCUMULATION_STEPS = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE

# Function to initialize the distributed process group
def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
...


if __name__ == "__main__":
    ddp_setup()
    main()
    destroy_process_group()

We can now run the code using torchrun this time:

torchrun --nproc_per_node=1 train_multi_gpu.py
>>> step: 0
>>> step: 1
>>> step: 2
...

3 - Fully Sharded Data Parallel (FSDP)

You can already run the code above on multiple GPUs, but each process would simply train its own copy of the model. Without gradient synchronization and data sharding, this is not effective distributed training.

To train your model across multiple nodes, you can use either FSDP or DDP from PyTorch. Both are forms of data parallelism (each GPU processes different data), but FSDP is more flexible and allows training larger models by reducing memory usage.
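For comparison, wrapping the model with DDP would look roughly like this (a minimal sketch, not used in the rest of this tutorial; it assumes ddp_setup() has already been called and that model is the Llama model loaded earlier):

from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ["LOCAL_RANK"])
model = model.to("cuda")
# DDP replicates the full model on every rank and all-reduces gradients during backward()
model = DDP(model, device_ids=[local_rank])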

To use torch FSDP, we need to wrap our model in an FSDP unit:

train_fsdp.py
...

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from contextlib import nullcontext

...

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

...

def main():
    ### load torch distributed env var
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    ...

    model = model.to("cuda")
    model = FSDP(model, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True)
    
    ...

    tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])["train"]
    tokenized_datasets = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

    ...

    for step, batch in enumerate(train_loader):
        ...

        with model.no_sync() if is_accumulating else nullcontext():
            outputs = model(**batch)
            loss = outputs.loss / GRAD_ACCUMULATION_STEPS
            loss.backward()

What these modifications do:

  • Wrap the model into an FSDP unit (with mixed precision)
  • Use the split_dataset_by_node utility to give each node/rank its own shard of the dataset
  • Use the no_sync context manager to skip gradient synchronization during the gradient accumulation phase (the assembled loop body is sketched below)
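
For clarity, here is the assembled loop body with these changes in place (a sketch combining the fragments above; variable names match the earlier snippets):

    for step, batch in enumerate(train_loader):

        # True while we are still accumulating gradients for the current global batch
        is_accumulating = (step + 1) % GRAD_ACCUMULATION_STEPS != 0

        for key in batch.keys():
            batch[key] = batch[key].to("cuda")

        # Skip the gradient communication while accumulating
        with model.no_sync() if is_accumulating else nullcontext():
            outputs = model(**batch)
            loss = outputs.loss / GRAD_ACCUMULATION_STEPS
            loss.backward()

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()

        print(f"step: {step}")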

You can now run this code on multiple GPUs using torchrun:

torchrun --nproc_per_node=8 train_fsdp.py
W0812 12:28:00.240000 139907889326976 torch/distributed/run.py:757] 
W0812 12:28:00.240000 139907889326976 torch/distributed/run.py:757] *****************************************
W0812 12:28:00.240000 139907889326976 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0812 12:28:00.240000 139907889326976 torch/distributed/run.py:757] *****************************************
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [00:39<00:00, 25.86it/s]
step: 0
step: 1
step: 2
step: 3
...

4 - Run on Multiple Nodes

The code is now ready to run on multiple nodes. You need a copy of the script on each node and the ability to SSH into each of them.

First, decide which node will be the master node; you will need its private IP. The private IP usually starts with 10. or 192..

You can find the IP using one of the commands below:

ip a | grep 10.

or

ip a | grep 192.

Then, on each node, run:

export RDZV_ENDPOINT=10.15.42.1:1234

Replace 10.15.42.1 with the private IP address of your master node and 1234 with any open port on the master node.

You then need to assign a rank to each node. The master node must have rank 0.

export MY_RANK=0
Don’t name this environment variable RANK, as RANK will conflict with the one used by torch.distributed.

Finally, to start the training, run the following command on each node:

torchrun --nproc_per_node=8 \
    --node_rank=$MY_RANK \
    --rdzv_endpoint=$RDZV_ENDPOINT \
    --nnodes=2 \
    train_fsdp.py
--nnodes should be adjusted to the number of nodes you have.

Each node should wait for the others to be ready before starting the training.
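
To verify that the rendezvous worked and that every node joined with the expected rank, you can print the distributed environment right after ddp_setup() (an optional sketch; not required for training):

import os
import torch.distributed as dist

def print_dist_info():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"rank {rank} (local rank {local_rank}) of {world_size} is ready")
    dist.barrier()  # blocks until every rank has reached this point
    if rank == 0:
        print("all ranks joined the process group")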

Troubleshooting

If torchrun is hanging forever or crashes, check that:

  • The node with rank 0 is the one whose private IP you used in RDZV_ENDPOINT
  • There are no duplicate ranks (each node should have a unique MY_RANK)
  • All ranks from 0 to nnodes - 1 are assigned (based on the number of nodes you specified)
  • The other machines can reach the private IP (use ping <MASTER_NODE_IP> to verify), or start a Python server on the master node with python -m http.server 1234 and check that the other nodes can reach it with curl http://$RDZV_ENDPOINT (see also the smoke test below)
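
If the rendezvous succeeds but training still hangs, a minimal all-reduce smoke test can help isolate communication issues from the training code (a hypothetical test_comm.py, launched with the same torchrun command as train_fsdp.py):

# test_comm.py - minimal communication smoke test (hypothetical helper script)
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Each rank contributes a tensor of ones; after all_reduce every rank should see world_size
x = torch.ones(1, device="cuda")
dist.all_reduce(x)
print(f"rank {os.environ['RANK']}: all_reduce result = {x.item()} (expected {dist.get_world_size()})")

dist.destroy_process_group()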

5 - Adapting the FSDP ShardingStrategy

FSDP stands for Fully Sharded Data Parallel, a data parallelism technique that allows training bigger models by sharding the model weights, gradients, and optimizer states across all ranks.

The FSDP class in PyTorch can go beyond simply doing fully sharded data parallel. It exposes different sharding strategies.

There is no free lunch when it comes to choosing one of these FSDP strategies. It all depends on your hardware, your number of nodes, and the size of the model you are training. A good way to make the decision is to benchmark the Model FLOPs Utilization (MFU) of different strategies on your hardware and model size. See this paper for more information.
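
As a back-of-the-envelope example, MFU can be estimated from your measured throughput using the common approximation of ~6 FLOPs per parameter per token for a forward + backward pass (a sketch; the throughput and peak-FLOPs numbers below are placeholders, not measurements from this tutorial):

def estimate_mfu(n_params: float, tokens_per_second: float, peak_flops_per_gpu: float, n_gpus: int) -> float:
    # ~6 FLOPs per parameter per token for forward + backward of a dense transformer
    achieved_flops_per_second = 6 * n_params * tokens_per_second
    return achieved_flops_per_second / (peak_flops_per_gpu * n_gpus)

# Placeholder example: 150M parameters, 200k tokens/s on 8 GPUs with ~312 TFLOPs peak bf16 each
print(f"MFU: {estimate_mfu(150e6, 200_000, 312e12, 8):.2%}")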

To change the FSDP strategy you can do:

from torch.distributed.fsdp import ShardingStrategy
...
def main():
    ...
    model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True)
    ...

NO_SHARD

model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True)    

The NO_SHARD strategy is equivalent to DDP (Distributed Data Parallel). The model weights, the optimizer state, and the gradients are all replicated on each rank. The gradients are all-reduced once per step, and the model update is done identically on each rank.

This strategy is the least communication-intensive but requires the most memory. It is useful for training small models (below 1B parameters).
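
Conceptually, the per-step communication of NO_SHARD boils down to a single gradient all-reduce before the optimizer step, roughly as in this illustrative sketch (FSDP/DDP handle this for you; do not add it to the training script):

import torch.distributed as dist

def all_reduce_grads(model):
    # Average gradients across all ranks so every replica applies the same update
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size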

SHARD_GRAD_OP

model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True)    

This strategy shards the gradients and the optimizer state across all ranks, but the model weights are replicated on each rank.

It is more communication-intensive than NO_SHARD because you need to communicate each time you want to modify the gradients or use the optimizer state, but it is less memory-intensive because the gradient and optimizer state are never duplicated.

Avoid using this strategy if your intra-node interconnect is not NVLink/NVSwitch or PCIe 5.0.

FULL_SHARD

model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True)    

This strategy is the default FSDP strategy. Nothing is replicated: the model weights, the optimizer state, and the gradients are all sharded across all ranks. This is the least memory-intensive strategy but requires the most communication: at each layer’s forward (and backward) pass, that layer’s full weights are materialized on every rank via an all-gather operation.

Avoid using this strategy if your intra-node interconnect is not NVSwitch (i.e., an SXM machine).

Hybrid techniques

You might have a very fast intra-node interconnect (between GPUs on the same node) but a relatively slow inter-node interconnect (e.g. 100 Gb/s Ethernet between nodes). In this case, you might want to use the most memory-efficient strategy (FULL_SHARD) within each node while using the least communication-intensive strategy (NO_SHARD) across nodes.

This can be done using a hybrid strategy. You first need to create a device mesh that represents your topology.

from torch.distributed.device_mesh import init_device_mesh

local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])  # number of GPUs on this node
nnodes = world_size // local_world_size
# 2D mesh: sharding within the "local" (intra-node) dimension, replication across the "global" (inter-node) dimension
device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True, device_mesh=device_mesh)

If you want to use the SHARD_GRAD_OP strategy within each node and NO_SHARD across nodes, you can use the _HYBRID_SHARD_ZERO2 strategy.
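
Presumably this is wired up the same way as the HYBRID_SHARD call above, i.e. something like:

# Same setup as above, but ZeRO-2-style sharding within each node and replication across nodes
model = FSDP(model, sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True, device_mesh=device_mesh)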

A hybrid strategy should be used if you don’t have InfiniBand between nodes or if you are training a relatively small model (< 7B parameters).

What if you have an even slower interconnect bandwidth (less than 100 Gb/s) between nodes? Check out OpenDiLoCo, our framework for low-bandwidth distributed training.