
A Short Guide to PyTorch DDP

In this blog post, we explore what torchrun and DistributedDataParallel are and how they can be used to speed up your neural network training by using multiple GPUs.

Neural networks, or even deep neural networks, are popular models for machine learning. Mathematically, they can be interpreted as nested functions with millions of parameters. If the parameters are tuned well, the network can be used to make predictions, for example predicting what a photo contains when given that photo. A famous example is Google Lens. The parameters are tuned using data. For example, if you show the network a photo of a dog and point out that it's a dog, the parameters are adjusted so that the next time it sees the same photo, it is more likely to recognise that it's a photo of a dog. This is done for millions of photos.

The Python package PyTorch can be used to specify a neural network and train it with data on a GPU. Even better, it can be trained using multiple GPUs, speeding up the training process.

Basic Concept of torchrun

After installing torch in your virtual environment (installation instructions are available on PyTorch's website), the command torchrun is usually available. It allows you to launch multiple workers, each running a copy of your Python script, and these workers can interact with each other.

You can try a simple example on your own computer to understand this. Consider a hello world script called hello_world.py

import os

RANK = int(os.environ["RANK"])

def say_hello(name):
    print(f"Hello {name}")

if __name__ == "__main__":
    names = [
        "Barry",
        "Alice",
        "Barbara",
        "Tom",
    ]
    say_hello(names[RANK])

You can launch four instances of hello_world.py, each with a different value of RANK, thus greeting a different person. This can be done with

torchrun --nproc-per-node 4 hello_world.py

This is pretty much the same as mpirun as described in a previous blog post.

Useful environment variables

Information about a worker can be obtained by using os.environ[] as shown above. Here are more useful environment variables:

  • RANK - The rank of the worker within a worker group.
  • WORLD_SIZE - The total number of workers in a worker group.
  • LOCAL_RANK - The rank of the worker within a local worker group.

There are more variables in PyTorch's documentation.
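
For example, a worker launched with torchrun can read all three variables in the same way as hello_world.py above (this snippet is just for illustration):

import os

RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])

print(f"I am worker {RANK} of {WORLD_SIZE} (local rank {LOCAL_RANK})")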

To use torchrun on Apocrita, launching a worker for each requested GPU, use the following in your job script

torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
  training_script training_script_args

where the positional arguments are

positional arguments:
  training_script       Full path to the (single GPU) training program/script to
                        be launched in parallel, followed by all the arguments
                        for the training script.
  training_script_args  Arguments to pass to your script; these can, for
                        example, be handled using argparse

Ensure you use a free port

The options --rdzv-backend=c10d and --rdzv-endpoint=localhost:0 should be used on Apocrita to ensure there are no port clashes should multiple users be running torchrun on the same node. This automatically assigns a free port to your job, otherwise, you may encounter a torch.distributed.DistNetworkError exception. See PyTorch's documentation for more information.

Assign a worker for each GPU

The option --nproc-per-node gpu will automatically launch a process for each requested GPU.

Single node multi-GPU jobs on Apocrita

At the time of writing, only single-node multi-GPU jobs are available on Apocrita. Up to 4 GPUs can be requested this way. torchrun does work with multiple nodes of GPUs, but this is beyond the scope of this blog.

With torchrun, you can use PyTorch's features such as DistributedDataParallel. This is a way to parallelise your machine learning code across multiple GPUs, and even multiple nodes of GPUs.

GPU Architecture for Machine Learning

In this section, we recap the architecture of a GPU and how it applies to typical PyTorch code. Figure 1 shows a photo of a commercial grade graphics card, which houses many components such as a GPU chip, memory, a heat sink and fans. When there is no need to distinguish these components, the entire card is quite commonly referred to as the GPU.

A photo of a graphics card

Figure 1: A photo of a commercial grade graphics card, Nvidia GTX 1660

For this blog, we will focus on the GPU and the memory housed on the card, commonly referred to as video RAM (VRAM). Figure 2 shows an illustration of this. Commercial grade cards, such as the GTX and RTX series, typically have about 2-24 GB of VRAM, whereas enterprise cards such as our A100 and H100 cards have either 40 GB or 80 GB. The amount of VRAM on the card you're using is important as it will need to hold the parameters of the neural network you want to train. Small and older neural networks such as AlexNet can fit on a commercial grade card, but larger, newer models, like RegNet and ConvNeXt, typically require enterprise cards.

Diagram showing a graphics card, consisting of a GPU and VRAM

Figure 2: A diagram showing a graphics card, consisting of a graphics processing unit (GPU) and memory, commonly referred to as video RAM (VRAM). The box underneath the VRAM shows its content in a typical PyTorch code.

It is also worth pointing out that datasets used in machine learning, such as ImageNet, are in the order of hundreds or thousands of GB. It just isn't possible to fit such a dataset in VRAM. Thus, typically, batches of data are loaded into VRAM one at a time when training a neural network. Optimisation methods, such as stochastic gradient descent, can use batches of data to train the neural network. Figure 2 illustrates how VRAM contains the parameters of the neural network and a batch of data. Figure 3 illustrates how a dataset can be split into batches; the GPU trains the neural network one batch of data at a time.

Diagram showing a dataset being split into 4 batches, each batch being loaded onto the GPU, one at a time

Figure 3: A diagram showing a dataset being split into batches. During training each batch of data is used to train a neural network on the GPU, one batch at a time.
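
To make this concrete, a typical PyTorch training loop only ever copies the current batch onto the GPU. A minimal sketch, where the dataset, model and batch size are placeholders, looks like this:

import torch

device = torch.device("cuda:0")
data_loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)

for image, target in data_loader:
    # only this batch, not the whole dataset, is copied into VRAM
    image = image.to(device)
    target = target.to(device)
    output = model(image)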

Assessing how much VRAM you need

Use tools such as nvidia-smi, nvtop and nvitop to monitor your GPU utilisation and how much VRAM is being used. They are available on Apocrita; please see the documentation.

If you encounter memory errors, you may need a GPU with more VRAM. It is also possible to fit your model across multiple GPUs but this is beyond the scope of this blog.

How DistributedDataParallel Uses Multiple GPUs

The idea behind DistributedDataParallel is that each GPU trains the same neural network using different batches of data in parallel, getting through all of the batches faster. Figure 4 illustrates this.

Diagram showing two GPUs, each working on different batches

Figure 4: A diagram showing two GPUs, each working on different batches of a dataset.

It should be noted that this is not a pleasingly parallel task. In a previous blog post, we looked at pleasingly parallel GPU tasks. Those tasks are pleasingly parallel because each GPU can work independently, for example on different models or parameters.

However, in the case of DistributedDataParallel, each GPU trains the same model, and the GPUs work together rather than independently. After each GPU works out the gradients for its batch, the gradients are collated and used to update the parameters of the neural network. The updated parameters are then sent to all GPUs, ready for the next batch. This is illustrated in Figure 5. The cycle continues until some stopping condition is met. This type of parallelisation can be described as tightly coupled.

Diagram showing two GPUs in a cycle of gathering gradients and updating parameters

Figure 5: For each batch, the gradients calculated from each GPU are gathered and used to update the parameters of the neural network. The updated parameters are then sent to all GPUs, ready for the next batch of data.
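
Conceptually, the gathering step in Figure 5 is an all-reduce over the gradients. DistributedDataParallel does this for you (and overlaps it with the backward pass), but a hand-written sketch of the idea, assuming the process group has already been initialised, might look like this:

import torch

def average_gradients(model, world_size):
    """Sum each parameter's gradient across all workers and divide by the
    number of workers, so every worker ends up with identical gradients."""
    for param in model.parameters():
        if param.grad is None:
            continue
        torch.distributed.all_reduce(
            param.grad, op=torch.distributed.ReduceOp.SUM)
        param.grad /= world_size

After this, every worker applies the same update in optimizer.step(), keeping the copies of the model in sync.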

Using DistributedDataParallel With Your Code

For reference, PyTorch has documentation on DistributedDataParallel such as their API documentation, their beginner tutorial and their intermediate tutorial. PyTorch also has example code on their GitHub. There are also blog posts such as those by Kevin Kaichuang Yang and Jackson Kek, the latter being my recommendation to read.

When reading these posts, it may be useful to note the following:

  • The alternative to DistributedDataParallel is DataParallel. For performance reasons, DistributedDataParallel should be favoured over DataParallel. See PyTorch's comparison for further information.
  • torchrun was introduced at around v1.10. Previously, you would run the library module python -m torch.distributed.launch.
  • On Apocrita, where GPU jobs are limited to one node, you do not need to set the variables os.environ['MASTER_ADDR'] and os.environ['MASTER_PORT'].

In this blog post, I'll provide some steps to help rewrite your existing PyTorch code to use DistributedDataParallel and then a full example script.

Steps To Use DistributedDataParallel

Get information about each worker

We first get information about each worker, such as the rank and the world size.

RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

MASTER_RANK = 0

I personally like to set these as global variables to make them easier to access in the middle of the code. They shouldn't change value during a run, so their contents stay crystal clear for the lifetime of the program.

I've also defined MASTER_RANK to designate a worker to be the master.

When to use the master rank

Some tasks, such as printing or logging, only need to be done by one worker while the rest skip them. This can be done using a conditional, for example

if RANK == MASTER_RANK:
    print("Here's some information")

The difference between LOCAL_RANK and RANK

RANK is unique across all workers, whereas LOCAL_RANK is only unique among the workers on the same node. For example, in a two-node job with four GPUs per node, RANK runs from 0 to 7 while LOCAL_RANK runs from 0 to 3 on each node. Typically in multi-node jobs, LOCAL_RANK is used to identify each GPU on a node. Apocrita does not support multi-node GPU jobs; however, the distinction is important should your code need to scale up.

Set the GPU

We set the GPU, or device, we want a worker to use. Typically we want one GPU per worker so we use torch.device(f"cuda:{LOCAL_RANK}") to assign a unique GPU to each worker. This is then followed by init_process_group() to make the workers aware of each other.

device = torch.device(f"cuda:{LOCAL_RANK}")
torch.cuda.set_device(device)
torch.distributed.init_process_group(
    backend="nccl", world_size=WORLD_SIZE, rank=RANK)
torch.distributed.barrier()

The nccl backend

We choose the nccl backend as we find this works best for multiple GPUs. See PyTorch's documentation for further information.

Move the model to the GPU and wrap it with DistributedDataParallel

model.to(device)
model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

Running code on a CPU instead

It is possible to run your code on multiple CPU cores instead of multiple GPUs by using device = torch.device("cpu") and tinkering with the code a bit more. You may want your code to be runnable on either CPU or GPU so that it is compatible with different systems. However, this can bloat your code.

Plan accordingly.
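
Below is a minimal sketch, not from the original code, of choosing the device and communication backend at run time so that the same script can run on either GPUs or CPU cores (the gloo backend supports CPU-only runs):

import os

import torch

RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

if torch.cuda.is_available():
    # one GPU per worker, as before
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    backend = "nccl"
else:
    # fall back to the CPU and a CPU-capable backend
    device = torch.device("cpu")
    backend = "gloo"

torch.distributed.init_process_group(
    backend=backend, world_size=WORLD_SIZE, rank=RANK)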

Set the DataLoader and DistributedSampler

Use a DistributedSampler; this ensures each GPU is allocated different data points. It is passed to your usual DataLoader.

train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

# set data loader
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
    pin_memory=True
)

Setting num_workers

The argument num_workers sets how many subprocesses the DataLoader uses to load data. Remember that when using torchrun, multiple copies of your script are executed. Thus, to avoid oversubscribing the CPU cores, num_workers should be no bigger than the number of CPU cores per worker (or GPU) minus one.

At the time of writing, on Apocrita, set this argument to 11 or less.

I recommend doing this programmatically. For example, num_workers can be an argument for your script. Or specifically for Apocrita

num_workers = int(os.getenv("NSLOTS")) // torch.cuda.device_count() - 1
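
If you would rather make num_workers an argument of your script, a minimal sketch using argparse (the argument name --num-workers is my own choice) could be:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-workers", type=int, default=1)
args = parser.parse_args()

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=args.num_workers,
    pin_memory=True
)

The value is then passed after the script name in the torchrun command, in the same way as training_script_args earlier.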

Train your model

Train your model as usual. By default, the DistributedSampler randomises the data. We pass the current epoch to set_epoch() so that at every epoch, a different random partition of the data is distributed to the GPUs.

model.train()

for epoch in range(n_epoch):

    train_sampler.set_epoch(epoch)

    for image, target in data_loader:
        image = image.to(device)
        target = target.to(device)
        output = model(image)

        ...

Run your code

Once your script has been modified, run your script with torchrun as explained in the previous section.

torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
  training_script training_script_args

Example Script

If you're still stuck, you can study and run the example script below. We wrote our own neural network model and trained it on the MNIST dataset (the dataset was studied previously in a blog post).

The code demonstrates how to use DistributedDataParallel and DistributedSampler and can run on Apocrita with one or more GPUs. We've also restricted downloading the MNIST dataset, and printing, to the master process so that the dataset is only downloaded once. The function torch.distributed.reduce is used to gather results from all GPUs.

To run the script, call

torchrun --nproc-per-node gpu --rdzv-backend=c10d --rdzv-endpoint=localhost:0 \
  training_script

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

import os


RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])

MASTER_RANK = 0

N_EPOCH = 5
BATCH_SIZE = 100


# Our own custom neural network on the MNIST dataset
class Net(nn.Module):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


def main():

    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(backend="nccl", world_size=WORLD_SIZE,
                                         rank=RANK)
    torch.distributed.barrier()

    model = Net()
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

    # place this condition so only one process downloads the mnist dataset
    if RANK == MASTER_RANK:
        dataset = torchvision.datasets.MNIST(
            root='.', train=True, transform=transforms.ToTensor(),
            download=True)
        torch.distributed.barrier()
    else:
        # all remaining processes can read the mnist dataset once the master
        # process finished downloading the mnist dataset
        torch.distributed.barrier()
        dataset = torchvision.datasets.MNIST(
                root='.', train=True, transform=transforms.ToTensor(),
                download=True)

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=11,
        pin_memory=True
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    model.train()  # set the model to training mode

    for epoch in range(N_EPOCH):

        sampler.set_epoch(epoch)
        total_loss = torch.zeros(1).to(device)

        for image, target in data_loader:
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)
            total_loss += loss.detach()  # detach so the autograd graph isn't retained
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # gather and sum the loss from each worker
        torch.distributed.reduce(total_loss, MASTER_RANK,
                                 torch.distributed.ReduceOp.SUM)

        # only the master rank prints the loss at every epoch
        if RANK == MASTER_RANK:
            print(f"Total loss: {total_loss[0]}")

if __name__ == "__main__":
    main()

Summary

We have explained what DistributedDataParallel is and how it can be used with torchrun and multiple GPUs to speed up your machine learning training scripts. We've provided some guidelines and tutorials to get you started using it on Apocrita. As you progress in your research, you may need additional features such as checkpointing, to save and resume training your model, and seeding, to make your training reproducible.
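
For example, a minimal sketch of checkpointing on top of the example script above (the filename is illustrative) saves only on the master rank and unwraps the DistributedDataParallel model with model.module:

# only one worker needs to write the checkpoint
if RANK == MASTER_RANK:
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        "checkpoint.pt",
    )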

In the next blog, we will benchmark DistributedDataParallel on different GPUs and neural networks to see how much of a performance gain we get from using DistributedDataParallel.

Acknowledgement

We would like to thank Niki Foteinopoulou for raising a ticket and discussing with us how to get DistributedDataParallel working on Apocrita.

The GPU illustration is from vecteezy.com.