Parallel Training of Neural Networks
For any serious practitioner of deep learning, there will be a point at which a single GPU is not sufficient for the problem you're trying to solve. This could be brought about by needing a model that is too large or training data that is high-dimensional. This blog post focuses on the latter: a large training dataset. Alternatively, you may simply be tired of sitting around waiting for your model to train and want to take better advantage of the hardware available. Regardless, training deep neural networks on multiple GPUs in parallel is a point many reach in their research or industrial career.
During the earlier parts of my academic career, I was well exposed to the foundational concepts of parallel computing at the HPC scale. In particular, this involved MPI programming in Fortran and AVX programming in C. While I would much rather not touch Fortran these days, those years taught me important and abstract ideas about how to build and think about parallel programs. These skills, which were once considered essential for scientific computing, are essentially non-existent for the overwhelming majority of deep learning practitioners. Now, you don't need to be an expert in MPI to use PyTorch or TensorFlow, but having some foundational knowledge of how to train networks in parallel can go a long way toward improving your deep learning workflow.
For me, the PyTorch parallel training wrappers simply did not operate correctly on the cluster I was on, in particular for inter-node communication. Instead of bugging the system admins for a fix, and not being too comfortable relying completely on code I can't see for parallel communication, I decided to just implement the parallel training manually. That's when I realized that there were way too many useless Medium articles and tutorials that tell you nothing about how to implement parallel training. Apparently, some people out there don't understand that regurgitating the PyTorch docs at lower quality with no working examples is not a tutorial. And thus here we are…

The Basics of Parallel SGD
Fortunately, understanding how parallel training works for neural networks is astonishingly simple, requiring zero prior knowledge of parallel programming.
Generally speaking, the only parallel operation required is all_reduce, which is used to communicate gradients between processes.
The vanilla training algorithm on each process is as follows:
- Forward pass of the model
- Compute the loss
- Back-propagation
- All-reduce to average the gradients
- Update the model (optimizer step)
And that's it (well, neglecting a few details)!
Loop this multiple times and you have a model training in parallel with exact gradient descent.
This process is illustrated in the figure below for four parallel processes.
Notice that most of the computation can be executed independently of the other tasks, the exception being the blocking all_reduce call.

There is some prep work we will need to do first, plus we still need to talk about what an all_reduce call actually is.
!! If you are using BatchNorm you will need to do more work on the forward pass !!
What About Other Methods?
A quick note on alternative methods.
The one we will focus on is exact gradient descent, which is what libraries like PyTorch implement behind the scenes with a few extra bells and whistles.
Unfortunately, the blocking nature of the all_reduce call needed for averaging the gradients can slow things down as we scale.
So multiple works have proposed alternative, typically approximate, algorithms for training deep learning models in parallel.
While not discussed in detail here, the advantage of implementing parallel training yourself is that you can then implement these other methods.
- Periodically averaging the model parameters instead of the gradients with a modified SGD [1].
- Federated learning, which uses subsets of the parallel processes to compute the gradient [2].
- Using shared memory to store a global model that is asynchronously updated between processes [3].
Does just averaging model parameters work? No. Optimization objectives for neural networks are typically non-convex with many local minima, and averaging weights alone tends to yield poor results [4].
Data Preparation
Before even touching the neural network, we first need to discuss how we handle the data between the parallel processes. The standard approach is to partition the training data into equal, independent segments which are then used as the local training dataset for each task. Given that each process has an ID and we know how many processes exist in the communication network, segmentation can be implemented as follows:
from random import Random
from torch.utils.data import Dataset

class ParallelDataset(Dataset):
    def __init__(self, data, proc_id: int = 0, num_procs: int = 1, seed: int = 1234):
        # Identical seed on every process so the shuffled index order matches across tasks
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)
        # Each process keeps a disjoint, equally sized slice of the shuffled indices
        seg_len = data_len // num_procs
        local_indexes = indexes[proc_id * seg_len: (proc_id + 1) * seg_len]
        self.examples = [data[index] for index in local_indexes]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i: int):
        return self.examples[i]
This essentially splits the dataset into separate segments for each task, as illustrated below. Note the setting of the random seed; this ensures the shuffle of indices is identical between parallel tasks.

But how do I get my process ID and number of processes?
Getting a unique ID for each parallel task (the rank) as well as the total number of tasks (the world size) is easy in most parallel frameworks. Some are listed below:
Library | proc_id | num_procs | Ref |
---|---|---|---|
PyTorch | torch.distributed.get_rank() | torch.distributed.get_world_size() | [Link] |
Python Threading | threading.get_ident() | threading.active_count() | [Link] |
MPI | MPI_Comm_rank | MPI_Comm_size | [Link] |
MPI4Py | mpi4py.MPI.COMM_WORLD.Get_rank() | mpi4py.MPI.COMM_WORLD.Get_size() | [Link] |
CuPy NCCL | cupy.cuda.nccl.NcclCommunicator.rank_id() | cupy.cuda.nccl.NcclCommunicator.size() | [Link] |
Note: For NCCL/CuPy it is suggested to use MPI for getting rank id/world size. This is because NCCL can have multiple workers on one CPU thread.
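As a quick usage sketch (assuming the mpi4py route from the table above and the ParallelDataset class defined earlier, with data standing in for any indexable dataset), wiring the rank and world size into the data pipeline looks something like this:

from mpi4py import MPI
from torch.utils.data import DataLoader

# Rank and world size from MPI (any of the libraries above works the same way)
mpi_comm = MPI.COMM_WORLD
rank = mpi_comm.Get_rank()
world_size = mpi_comm.Get_size()

# Each process builds its own disjoint shard of the training data
train_dataset = ParallelDataset(data, proc_id=rank, num_procs=world_size)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)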
Understanding All Reduce
With the training data segmented, we can now jump into the parallel programming.
As previously stated, the only part we need to implement ourselves is an all_reduce call between processes, the goal of which is to average the parameter gradients across all processes.
Given a chunk of data on every process, all_reduce is a parallel algorithm that computes an aggregate value from the combination of the data on each parallel task.
The standard supported reduction operations include SUM, MAX, MIN, and PROD, but some boolean operations are also available.

all_reduce is an operation that requires information from all processes involved, meaning that in our parallel program it will be a point of synchronous execution.
How it is carried out depends on the library we use, as discussed in the later sections.
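To make the semantics concrete, here is a tiny sketch (using mpi4py, though every library discussed below behaves the same way) in which each process contributes a single number and every process receives the sum:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Every process contributes its own rank; all_reduce with SUM leaves
# the same aggregate (0 + 1 + ... + size-1) on every process.
total = comm.allreduce(rank, op=MPI.SUM)
print(f"rank {rank}: sum = {total}, average = {total / size}")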
The SUM operation is what we will be using, since we ultimately want to average the gradients; this is easily accomplished by dividing the summed gradients by the number of parallel processes. This parallel averaging of gradients follows naturally from standard back-propagation: on a single process, the gradients for a particular parameter are typically already averaged over the batch [5].
With a local batch size of $N$ on each of $M$ parallel processes, the global batch size is now $N\times M$. This is important because it means you can use bigger batches, provided your model still optimizes well. Alternatively, for the same global batch size, you can use smaller local batch sizes, freeing up VRAM and allowing for a bigger model.
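Putting the pieces together, each process runs an ordinary training loop with one extra call after back-propagation. Below is a minimal sketch, assuming the sharded train_loader from the data preparation section and an average_gradients helper like the ones implemented in the next section (names such as model, device, epochs, and the loss are placeholders):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

for epoch in range(epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)             # 1. Forward pass
        loss = criterion(outputs, targets)  # 2. Compute the loss
        loss.backward()                     # 3. Back-propagation
        average_gradients(model)            # 4. All-reduce average of the gradients
        optimizer.step()                    # 5. Update the model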
Implementation
Finally, let's get to some parallel code. We're going to explore several different parallel libraries with Python wrappers to see how implementation and performance differ.
MPI (mpi4py)
Message Passing Interface (MPI) is a standard designed for communication on parallel computing architectures. MPI is designed for communication between CPUs, meaning we will need to move the gradients on and off the GPU. We will be using the Python wrapper for MPI, mpi4py, which supports various MPI libraries while keeping the Python function calls the same. I personally used OpenMPI. The implementation of the function to average the gradients between tasks follows:
from mpi4py import MPI

def average_gradients(model, device='cpu'):
    """
    Args:
        model (nn.Module): PyTorch model
        device (torch.device, optional): PyTorch device the model lives on.
    """
    mpi_comm = MPI.COMM_WORLD
    size = mpi_comm.Get_size()
    # Loop through model parameters
    for param in model.parameters():
        # If the parameter has a gradient, exchange it
        if param.grad is not None:
            # Gradients must be moved to the CPU for MPI
            param_cpu = param.grad.data.cpu()
            # allreduce(Send buffer, Op)
            param_sum = mpi_comm.allreduce(param_cpu, op=MPI.SUM)
            # Average and move back to the original device
            param.grad.data = param_sum.to(device) / size
A nice feature of the mpi4py library is that you do not need to specify the data type of the input to mpi_comm.allreduce; mpi4py will just pickle the data and send it as binary, which the docs claim is near C speed.
The big limitation here is that we need to offload param.grad.data to the CPU and then move the averaged param_sum back to the GPU.
Some MPI implementations support GPU-aware communication, which allows you to pass data that resides in GPU memory, but mpi4py support for these features is still in development and there are better libraries for this, as we will see.
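If the pickling overhead ever becomes a concern, mpi4py also exposes buffer-based (uppercase) collectives that operate directly on NumPy arrays without serialization. A rough sketch of the same averaging using Allreduce with MPI.IN_PLACE (the function name and structure here are just an illustration, assuming contiguous CPU gradients):

import numpy as np
import torch
from mpi4py import MPI

def average_gradients_buffered(model, device='cpu'):
    # Variant of average_gradients using mpi4py's buffer-based Allreduce
    mpi_comm = MPI.COMM_WORLD
    size = mpi_comm.Get_size()
    for param in model.parameters():
        if param.grad is not None:
            # Contiguous CPU copy of the gradient exposed as a NumPy buffer
            grad_np = np.ascontiguousarray(param.grad.data.cpu().numpy())
            # In-place sum across all processes into the same buffer
            mpi_comm.Allreduce(MPI.IN_PLACE, grad_np, op=MPI.SUM)
            param.grad.data = torch.from_numpy(grad_np).to(device) / size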
MPI was originally designed in a time before GPU accelerators were a core part of high-performance clusters (now referred to as heterogeneous architectures).
The standard way the all_reduce call functions in MPI is by forming a binary tree structure over all MPI tasks: MPI will first reduce the data onto a master process before then broadcasting it back to all other processes [6].
This tree structure is intelligently built at program start based on the communication latency between nodes.
But MPI alone misses out on technology present in modern GPU hardware that can speed things up.
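For intuition, that reduce-then-broadcast pattern can be spelled out explicitly with mpi4py; a plain allreduce is logically equivalent to this two-step sketch (only for illustration, the built-in call should be preferred):

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

local_value = float(rank + 1)

# Step 1: reduce everything onto the root process (rank 0)
total = comm.reduce(local_value, op=MPI.SUM, root=0)

# Step 2: broadcast the result back out to every process
total = comm.bcast(total, root=0)
print(f"rank {rank}: total = {total}")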

NCCL (CuPy)
The NVIDIA Collective Communication Library (NCCL) is a CUDA library for multi-GPU and inter-node communication between GPUs. We will be using CuPy to access the NCCL operations from Python. NCCL alone is a lot more limited than MPI in terms of built-in operations because it is designed solely for GPUs. Thus, I suggest that if you choose to develop NCCL applications, you also use MPI to make things like setting up and synchronizing processes easier. This can get complex, because you can have multiple NCCL communicators on different CUDA streams assigned to one MPI process, but we are going to keep it to one NCCL communicator per MPI process.
The wonderful thing about NCCL is that it can take advantage of hardware built onto GPUs to speed up multi-GPU communication, such as PCIe lanes/NVLink for intra-node communication and InfiniBand Verbs/IP sockets for inter-node communication. Additionally, NCCL does not require us to take the data off the GPU! Let's check out some code.
import cupy as cp
from cupy.cuda import nccl
from mpi4py import MPI
import torch

# Setting up the NCCL communicator with MPI
mpi_comm = MPI.COMM_WORLD
size = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

# The NCCL communication ID establishes the communication network;
# it is created on rank 0 and broadcast to every other rank.
nccl_comm_id = None
if rank == 0:
    nccl_comm_id = nccl.get_unique_id()
nccl_comm_id = mpi_comm.bcast(nccl_comm_id, root=0)

# Set the active CUDA device (one GPU per MPI rank)
cp.cuda.runtime.setDevice(rank)

# Init NCCL communicator
nccl_comm = nccl.NcclCommunicator(size, nccl_comm_id, rank)

def get_NCCL_dtype(dtype):
    """
    Args:
        dtype (torch.dtype or cupy dtype): Tensor data type
    """
    if dtype == cp.float32 or dtype == torch.float32:
        return nccl.NCCL_FLOAT32
    elif dtype == cp.float64 or dtype == torch.float64:
        return nccl.NCCL_FLOAT64
    # ... add more data types if needed ...
    else:
        raise ValueError("This dtype is not implemented.")

def average_gradients(model, device='cpu', stream=None):
    """
    Args:
        model (nn.Module): PyTorch model
        device (torch.device, optional): PyTorch device.
        stream (cupy.cuda.Stream, optional): CUDA stream to run the all-reduce on.
    """
    if stream is None:
        stream = cp.cuda.Stream.null.ptr
    else:
        stream = stream.ptr
    # Loop through model parameters
    for param in model.parameters():
        # If the parameter has a gradient, exchange it
        if param.grad is not None:
            # Get the NCCL data type and the number of elements in the tensor
            dtype = get_NCCL_dtype(param.grad.dtype)
            count = torch.numel(param.grad)
            recv_buff = torch.zeros_like(param.grad)
            # allReduce(Send buffer, Recv buffer, Count, Data-type, Op, Stream)
            nccl_comm.allReduce(param.grad.data_ptr(), recv_buff.data_ptr(), count, dtype, nccl.NCCL_SUM, stream)
            # Average by the number of parallel processes (the MPI world size)
            param.grad = recv_buff / size
A little more involved.
With NCCL we need to set things up with MPI before we can get going.
Additionally, unlike mpi4py, we need to specify the datatype of our tensor and its number of elements before the all_reduce, which is standard practice.
Note that we are providing pointers to the all_reduce call via param.grad.data_ptr() and recv_buff.data_ptr(), allowing us to avoid offloading the gradients to the CPU.
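As a hypothetical usage sketch (assuming the setup above together with a model, loss, and optimizer), the function can be run on the default null stream or on an explicit CuPy stream, in which case it is worth synchronizing before the optimizer step:

# Optional: run the all-reduce on a dedicated CuPy stream
stream = cp.cuda.Stream(non_blocking=True)

loss.backward()
torch.cuda.synchronize()  # make sure PyTorch has finished writing the gradients
average_gradients(model, stream=stream)
stream.synchronize()      # reduced gradients are ready before the update
optimizer.step()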
An interesting characteristic of NCCL is that by default it uses a ring-based collective for structuring communication rather than a tree. A ring-based communication network can be bandwidth-optimal (the fastest) on many compute topologies [7]. Ring-based communication works by chunking the data into a finite number of parts that are individually passed along a circular path between processes.

Note how ring communication naturally splits the data into small messages, which is ideal for big data structures (such as the weights of a neural network) that would otherwise be throttled by bandwidth. The communication network is not necessarily just one big ring: NCCL builds it based on the installed hardware, so for example each node may form its own ring with a larger ring for inter-node communication. That is beyond the scope of this blog post though.
IMPORTANT ENVIRONMENT VARIABLES
- NCCL_SOCKET_IFNAME: Interface name(s) of the socket for inter-node communication. Use the command ifconfig to figure this out (e.g. eth0, ens1f0). Can be ignored if you're on a single node. See the docs for more details.
- NCCL_IB_DISABLE: Inter-node communication option. Setting this will disable the IB/RoCE transport interfaces and fall back to using IP sockets. Can be a useful fallback if things are failing between nodes.
Gloo (PyTorch)
Gloo is a collective communications library developed by Facebook with functions aimed specifically at machine learning. By default, Gloo is designed for communication between CPUs but can be compiled to work with NCCL (we won't be doing that). Since Gloo is designed specifically for deep learning, the focus is on providing collective algorithms: functions that involve all processes in the network, such as reduce, gather, scatter, broadcast, and barrier. From the C++ programming side of things, Gloo may offer a higher-level, easier way to implement parallel communication on top of existing packages.
Now Gloo doesn't have an explicit Python wrapper we can use, so we will use the next best thing: PyTorch.
After all, PyTorch is developed by Facebook, so it makes sense that Gloo would be supported.
This may seem backwards given the original message of this blog post, but we will still be implementing the communication manually rather than taking the easy way out with DistributedDataParallel.
Let's check out the code.
import os
import torch.distributed as torch_dist
from mpi4py import MPI

# Setting up torch.distributed with MPI
mpi_comm = MPI.COMM_WORLD
size = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

backend = "gloo"
os.environ['MASTER_ADDR'] = '127.0.0.1'  # IP of the master task
os.environ['MASTER_PORT'] = '29500'
torch_dist.init_process_group(backend, rank=rank, world_size=size)

def average_gradients(model):
    """
    Args:
        model (nn.Module): PyTorch model
    """
    # Loop through model parameters
    for param in model.parameters():
        # If the parameter has a gradient, exchange it
        if param.grad is not None:
            # Sum the gradients across all processes, then average
            torch_dist.all_reduce(param.grad, op=torch_dist.ReduceOp.SUM)
            param.grad.data = param.grad.data / size
Wow, that was easy compared to the others, though there are a few environment variables we need to make sure to set prior to initialization. Even with mpi4py's pickling features, the convenience of Gloo being baked right into PyTorch means minimal work on our end. Similar to NCCL, Gloo tends to use ring-based communication, but it has multiple approaches implemented, based on what is in the GitHub repo.
IMPORTANT ENVIRONMENT VARIABLES
- MASTER_ADDR: The IP address of the master task. If you are just using intra-node communication this can be set to the default 127.0.0.1 or localhost. For inter-node communication you will either need to know this ahead of time or query it at run time and then broadcast it to the other tasks. Check the blog code for an example of how to do this dynamically.
- MASTER_PORT: The port for the master task to use for communication. 29500 has worked for me thus far.
- GLOO_SOCKET_IFNAME: Interface name(s) of the socket for inter-node communication. Use the command ifconfig to figure this out (e.g. eth0, ens1f0). Can be ignored if you're on a single node.
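One way to set MASTER_ADDR dynamically (a sketch of the general idea rather than exactly what the blog's repository does) is to have rank 0 look up its own IP address and broadcast it over MPI before calling init_process_group:

import os
import socket
from mpi4py import MPI

mpi_comm = MPI.COMM_WORLD
rank = mpi_comm.Get_rank()

# Rank 0 resolves its own hostname to an IP and shares it with every task
master_addr = socket.gethostbyname(socket.gethostname()) if rank == 0 else None
master_addr = mpi_comm.bcast(master_addr, root=0)

os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = '29500'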
NCCL (PyTorch)
Now that we know how to implement NCCL calls from the lower-level CUDA side with CuPy, let's have a look at the PyTorch approach. Generally speaking, this is exactly the same as the Gloo approach: again, we will use MPI to set up the parallel network and then leave all the heavy lifting to NCCL.
import os
import torch.distributed as torch_dist
from mpi4py import MPI

# Setting up torch.distributed with MPI
mpi_comm = MPI.COMM_WORLD
size = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

backend = "nccl"
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
torch_dist.init_process_group(backend, rank=rank, world_size=size)

def average_gradients(model):
    """
    Args:
        model (nn.Module): PyTorch model
    """
    # Loop through model parameters
    for param in model.parameters():
        # If the parameter has a gradient, exchange it
        if param.grad is not None:
            # Sum the gradients across all processes, then average
            torch_dist.all_reduce(param.grad, op=torch_dist.ReduceOp.SUM)
            param.grad.data = param.grad.data / size
The only difference is that the backend parameter is now nccl.
Of course, you're probably wondering why even bother with CuPy if it's this easy in PyTorch.
Well, CuPy gives you more control over the details of the communication: you could create a really complicated parallel network of multiple processes and CUDA streams with CuPy that isn't achievable in PyTorch.
IMPORTANT ENVIRONMENT VARIABLES
- MASTER_ADDR: The IP address of the master task. If you are just using intra-node communication this can be set to the default 127.0.0.1 or localhost. For inter-node communication you will either need to know this ahead of time or query it at run time and then broadcast it to the other tasks. Check the blog code for an example of how to do this dynamically.
- MASTER_PORT: The port for the master task to use for communication. 29500 has worked for me thus far.
- NCCL_SOCKET_IFNAME: Interface name(s) of the socket for inter-node communication. Use the command ifconfig to figure this out (e.g. eth0, ens1f0). Can be ignored if you're on a single node.
Results
For the code used to get these results, please see the GitHub link after the conclusion. Additional quality-of-life parallel functions are implemented, such as a broadcast call to distribute the initial weights from the root process. Additionally, there is a function used to compare weights between the models on each process to ensure training consistency.
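As a rough sketch of that weight broadcast (using torch.distributed, assuming the process group has already been initialized as in the sections above):

import torch.distributed as torch_dist

def broadcast_parameters(model, root: int = 0):
    # Copy the root process's parameters onto every other process so that
    # all models start training from identical weights.
    for param in model.parameters():
        torch_dist.broadcast(param.data, src=root)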
Franke’s Equation
As a sanity check to make sure everything is implemented correctly, we will consider a deep neural network to predict Franke's equation. This is just an arbitrary 2D function with several modes that is used as a benchmark for interpolation methods. \[ \begin{split}f(x,y) = 0.75\exp\left(-\frac{(9x - 2)^{2}}{4}-\frac{(9y - 2)^{2}}{4}\right) + 0.75\exp\left(-\frac{(9x + 1)^{2}}{49}-\frac{(9y + 1)^{2}}{10}\right) \\ + 0.5\exp\left(-\frac{(9x -7)^{2}}{4}-\frac{(9y -3)^{2}}{4}\right) - 0.2\exp\left(-(9x -4)^{2}-(9y -7)^{2}\right) \end{split}\] Our neural network is a simple fully-connected model. This problem gives us a good baseline to explore how these different parallel training frameworks compare in a low-dimensional setting.
| Model Parameters | | Training Parameters | |
|---|---|---|---|
| Network | $2\rightarrow 100\rightarrow 100\rightarrow 1$ | Training Cases | 10000 |
| Activation | ReLU | Testing Cases | 1000 |
| | | Mini-batch Size | 256 |
| | | Epochs | 100 |
Since parallel training for this problem is overkill, we will focus on the concept of weak scaling, which is used to assess how the communication overhead scales. We will keep the global batch size the same while increasing the number of GPUs, meaning each model will be updated exactly the same number of times. This allows for easy verification that each communication implementation is working identically.
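Concretely, keeping the global batch size fixed just means shrinking the local batch size as processes are added; a tiny sketch assuming a global batch of 256 and the names from the earlier data preparation sketch:

global_batch_size = 256
local_batch_size = global_batch_size // world_size  # e.g. 64 per GPU with 4 GPUs
train_loader = DataLoader(train_dataset, batch_size=local_batch_size, shuffle=True)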

We measure the performance of each parallel library using the time it takes to train for 100 epochs. Each parallel configuration is run three times, from which the mean and standard deviation are plotted in the figure below. Overall, both NCCL implementations perform the best, with minimal decay in performance due to communication overhead, the CuPy implementation being the fastest. Both MPI and Gloo show a decay in performance as the number of GPUs increases, with Gloo performing the worst. It's interesting that MPI performs better than Gloo here, because this will not be the case in the next example.

Additionally, this is a great example that more does not always equal better. Here we reach a saturation in performance past 4 GPUs; there is no advantage to going beyond that in this case. So benchmark!
MNIST
With the implementations verified, we will look at training an MNIST variational auto-encoder as fast as possible. We will assume that the VRAM on our GPU is limited and increase the batch size as we increase the number of GPUs we use, which is a realistic scenario to be in in general. (Side note: a bigger batch size isn't always better! So adding GPUs may hurt your predictive accuracy.) It's not exactly realistic for MNIST on current hardware, but this is just for illustrative purposes.
| Model Parameters | | Training Parameters | |
|---|---|---|---|
| Encoder | $784 \rightarrow 512 \rightarrow 256 \rightarrow 4$ | Training Cases | 60000 |
| Decoder | $2\rightarrow 256 \rightarrow 512 \rightarrow 784$ | Testing Cases | 10000 |
| Activation | ReLU | Mini-batch Size | 64, 128, 256, 512 |
| | | Epochs | 300 |
Considering this is just a fully-connected auto-encoder, one should not expect any results too revolutionary. Below are a few reconstructions using the trained model as well as some results from interpolating around the learned latent space.


The wall-clock times for training this fairly small model are plotted below. Once again the NCCL implementations perform much better, with MPI and Gloo being worse than serial training because of the communication overhead. In reality, MNIST has a small enough dimensionality that a single GPU works just fine; for a dataset whose inputs have a much larger memory footprint, parallelizing your training will be essential. But once again, remember to benchmark: you may hit a saturation in performance, particularly when you move to inter-node communication.

Final Thoughts
Implementing your own parallel communication is actually pretty easy if you have a guide. But does that mean you should do it yourself? Well, that depends… if you have some complicated training process or are looking to explore parallel optimization, then 100%. But if you're just sticking with a standard model and supervised learning, well, just stick with PyTorch distributed if it works. Keep in mind, though, that what the PyTorch docs recommend may not always be ideal for your hardware! If your model is simple enough, follow one of the 100 other parallel tutorials on the DistributedDataParallel class in PyTorch. DDP does have some minor optimizations under the hood to make things go a little faster, but nothing completely game changing. Hope you enjoyed this and learned something.
A big thanks to Oliver from Nvidia for giving me the idea to play with NCCL in the first place.
Stay critical and thanks for reading!
Code
You can grab all the code for this post on Github.
References
- D. Povey, X. Zhang, S. Khudanpur, Parallel training of deep neural networks with natural gradient and parameter averaging, ArXiv Preprint ArXiv:1410.7455. (2014). [Link]
- B. McMahan, E. Moore, D. Ramage, S. Hampson, B.A. y Arcas, Communication-efficient learning of deep networks from decentralized data, in: Artificial Intelligence and Statistics, PMLR, 2017: pp. 1273–1282. [Link]
- B. Recht, C. Re, S. Wright, F. Niu, Hogwild!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent, in: J. Shawe-Taylor, R. Zemel, P. Bartlett, F. Pereira, K.Q. Weinberger (Eds.), Advances in Neural Information Processing Systems, Curran Associates, Inc., 2011. [Link]
- I.J. Goodfellow, O. Vinyals, A.M. Saxe, Qualitatively characterizing neural network optimization problems, ArXiv Preprint ArXiv:1412.6544. (2014). [Link]
- C.M. Bishop, Pattern Recognition and Machine Learning (Information Science and Statistics), Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
- A. Bienz, L. Olson, W. Gropp, Node-aware improvements to allreduce, in: 2019 IEEE/ACM Workshop on Exascale MPI (ExaMPI), IEEE, 2019: pp. 19–28. [Link]
- P. Patarasuk, X. Yuan, Bandwidth optimal all-reduce algorithms for clusters of workstations, Journal of Parallel and Distributed Computing. 69 (2009) 117–124. [Link]