To an extent, the latest advancements in Machine Learning are only possible because of **dedicated hardware** and methods to **parallelize model training**. Let’s dig today into the following aspects:

- Data Parallelism
- Model Parallelism
- Why is a TPU faster than a GPU?

## Data Parallelism

Let's speed up our training a notch! Do you know how the back-propagation computation gets distributed across GPUs or nodes? The two typical strategies to distribute the computation are data parallelism and model parallelism.

The steps for centralized synchronous data parallelism are as follows:

1. A parameter server is used as the ground truth for the model weights. The weights are duplicated into multiple processes running on different hardware (GPUs on the same machine or on multiple machines).
2. Each duplicate model receives a different data mini-batch and independently runs the forward and backward passes, where the gradients get computed.
3. The gradients are sent to the parameter server, where they get averaged once they have all been received. The weights are updated with a gradient descent step, and the new weights are broadcast back to all the worker nodes.
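
The three steps above can be simulated in a single process. This is a minimal sketch, assuming a toy 1-D linear model trained with mean squared error; the `ParameterServer` class and the helper functions are hypothetical illustrations, not part of any framework, and the "workers" are plain function calls rather than processes on different GPUs.

```python
# Minimal sketch: centralized synchronous data parallelism, simulated in one process.
# Assumptions: toy model y = w * x + b with a mean squared error loss;
# workers are function calls standing in for processes on different GPUs.

def grad_mse(w, b, batch):
    """Gradient of the mean squared error of y = w*x + b over one mini-batch."""
    n = len(batch)
    dw = sum(2 * (w * x + b - y) * x for x, y in batch) / n
    db = sum(2 * (w * x + b - y) for x, y in batch) / n
    return dw, db

class ParameterServer:
    """Holds the ground-truth weights (hypothetical helper class)."""

    def __init__(self, w=0.0, b=0.0, lr=0.01):
        self.w, self.b, self.lr = w, b, lr

    def step(self, worker_grads):
        # Synchronous: called only once every worker's gradients have arrived.
        k = len(worker_grads)
        dw = sum(g[0] for g in worker_grads) / k  # average the gradients...
        db = sum(g[1] for g in worker_grads) / k
        self.w -= self.lr * dw                    # ...and take one gradient descent step
        self.b -= self.lr * db
        return self.w, self.b                     # new weights broadcast to the workers

def train_step(server, shards):
    w, b = server.w, server.b                             # 1. workers get a copy of the weights
    grads = [grad_mse(w, b, shard) for shard in shards]   # 2. forward/backward on each mini-batch
    return server.step(grads)                             # 3. average on the server and update

# Toy data generated from y = 3x + 1, split into 4 equal mini-batches (one per worker).
data = [(x / 10, 3 * (x / 10) + 1) for x in range(40)]
shards = [data[i::4] for i in range(4)]
server = ParameterServer(lr=0.1)
for _ in range(2000):
    train_step(server, shards)
print(round(server.w, 3), round(server.b, 3))
```

Because all the mini-batches have the same size, the average of the per-worker gradients is exactly the full-batch gradient, which is why this synchronous centralized scheme behaves like ordinary gradient descent.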

This process is called "centralized": it is the gradients that get averaged. Another version of the algorithm is "decentralized", where it is the resulting model weights that get averaged:

1. A master process broadcasts the weights of the model.
2. Each process runs through multiple iterations of forward and backward passes with different data mini-batches. By then, the weights of the different processes have drifted apart.
3. The weights are sent to the master process, averaged across processes once they have all been received, and the averaged weights are broadcast back to all the worker nodes.
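
The decentralized variant only changes what gets averaged. This minimal sketch keeps the same assumptions (toy 1-D linear model, workers simulated as function calls); `local_training` and `averaging_round` are hypothetical helpers, and 5 local steps per round is an arbitrary choice.

```python
# Minimal sketch: "decentralized" synchronous data parallelism. Each worker runs several
# local SGD steps, then the master averages the *weights* instead of the gradients.
# Assumptions: toy model y = w * x + b with MSE loss; workers simulated sequentially.

def grad_mse(w, b, batch):
    """Gradient of the mean squared error of y = w*x + b over one mini-batch."""
    n = len(batch)
    dw = sum(2 * (w * x + b - y) * x for x, y in batch) / n
    db = sum(2 * (w * x + b - y) for x, y in batch) / n
    return dw, db

def local_training(w, b, shard, lr, local_steps):
    """One worker: several forward/backward passes on its own data, no communication."""
    for _ in range(local_steps):
        dw, db = grad_mse(w, b, shard)
        w, b = w - lr * dw, b - lr * db
    return w, b

def averaging_round(w, b, shards, lr=0.1, local_steps=5):
    # 1. The master broadcasts (w, b); 2. each worker trains locally and drifts apart.
    results = [local_training(w, b, shard, lr, local_steps) for shard in shards]
    # 3. The master averages the resulting weights; they are broadcast in the next round.
    k = len(results)
    return sum(r[0] for r in results) / k, sum(r[1] for r in results) / k

data = [(x / 10, 3 * (x / 10) + 1) for x in range(40)]
shards = [data[i::4] for i in range(4)]
w, b = 0.0, 0.0
for _ in range(400):
    w, b = averaging_round(w, b, shards)
print(round(w, 3), round(b, 3))
```

With `local_steps=1`, averaging the weights gives exactly the same update as averaging the gradients, since the mean of `w - lr * g_i` is `w - lr * mean(g_i)`. The extra local steps are what save communication, and they are also why the result no longer matches a gradient step on the averaged gradients.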

The decentralized approach can be much faster because the machines need to communicate far less often, but it is not a proper implementation of the back-propagation algorithm: averaging the weights after several local steps is not equivalent to taking gradient steps with the averaged gradients. Both processes are synchronous because we need to wait for all the workers to finish their jobs. The same processes can also run asynchronously: nobody waits for all the workers, and the gradients or weights of each worker are applied as they arrive instead of being averaged. You can learn more about it here: “Distributed Training of Deep Learning Models: A Taxonomic Perspective”.

When it comes to the centralized synchronous approach, PyTorch and TensorFlow follow a slightly different strategy [1]: they don't use a parameter server, and the gradients are instead synchronized and averaged directly among the worker processes with an all-reduce operation. This is how the PyTorch *DistributedDataParallel* module is implemented [2], as well as the TensorFlow *MultiWorkerMirroredStrategy* one [3]. It is impressive how simple they have made training a model in a distributed fashion!
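
To see how gradients can be averaged without a parameter server, here is a minimal single-process simulation of a ring all-reduce, one common all-reduce algorithm. It is an illustrative sketch, not the frameworks' actual code: PyTorch and TensorFlow delegate this step to collective communication backends (such as NCCL or Gloo), which may choose ring, tree or other algorithms. The `ring_allreduce` function is hypothetical.

```python
# Minimal sketch of a ring all-reduce (simulated in one process).
# Each worker only exchanges data with its right-hand neighbour; there is no central server.

def ring_allreduce(worker_grads):
    """Return, for every worker, the element-wise average of all workers' gradient vectors."""
    n = len(worker_grads)
    size = len(worker_grads[0])
    bounds = [size * c // n for c in range(n + 1)]   # chunk c covers bounds[c]:bounds[c + 1]
    buf = [list(g) for g in worker_grads]            # each worker's local buffer

    def chunk(i, c):
        return buf[i][bounds[c]:bounds[c + 1]]       # slicing returns a copy

    # Phase 1, reduce-scatter: after n - 1 steps, worker i holds the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        msgs = [((i + 1) % n, (i - step) % n, chunk(i, (i - step) % n)) for i in range(n)]
        for dst, c, data in msgs:                    # all sends happen "at the same time"
            for k, v in enumerate(data):
                buf[dst][bounds[c] + k] += v

    # Phase 2, all-gather: pass the fully reduced chunks around the ring.
    for step in range(n - 1):
        msgs = [((i + 1) % n, (i + 1 - step) % n, chunk(i, (i + 1 - step) % n)) for i in range(n)]
        for dst, c, data in msgs:
            buf[dst][bounds[c]:bounds[c + 1]] = data

    # Every worker now holds the same sums; divide by n to get the average gradient.
    return [[v / n for v in b] for b in buf]

grads = [[1.0, 2.0, 3.0, 4.0, 5.0], [3.0, 2.0, 1.0, 0.0, -1.0], [2.0, 2.0, 2.0, 2.0, 2.0]]
averaged = ring_allreduce(grads)
print(averaged[0])  # every worker ends up with [2.0, 2.0, 2.0, 2.0, 2.0]
```

Each worker sends 2 * (n - 1) messages of roughly 1/n of the gradient size, so the traffic per worker stays roughly constant as workers are added, and no single node becomes a bottleneck the way a parameter server can.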

## Model Parallelism

How would you go about training a 1-trillion-parameter model? It turns out there is a bit more to it than just "*model.fit(X, y)*"! I think a trend is starting to emerge where, due to the unprecedented scale of the latest Machine Learning models, the scaling strategies are becoming more and more intertwined with the modeling itself.
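
A quick back-of-the-envelope calculation shows why a plain `model.fit(X, y)` is not enough. This sketch assumes mixed-precision training with the Adam optimizer at 16 bytes per parameter (fp16 weights and gradients plus fp32 master weights, momentum and variance, the accounting used in Microsoft's ZeRO paper), ignores activations entirely, and assumes 80 GB GPUs; the helper names are hypothetical.

```python
import math

# Bytes per parameter for mixed-precision Adam training (activations ignored):
# fp16 weights (2) + fp16 gradients (2) + fp32 master weights (4) + fp32 momentum (4) + fp32 variance (4)
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4

def training_memory_gb(num_params, bytes_per_param=BYTES_PER_PARAM):
    """Memory needed for the model states alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

def min_gpus(num_params, gpu_memory_gb=80):
    """Lower bound on the number of GPUs needed just to hold the model states."""
    return math.ceil(training_memory_gb(num_params) / gpu_memory_gb)

for name, params in [("GPT-3 (175B)", 175 * 10**9), ("1T-parameter model", 10**12)]:
    print(f"{name}: {training_memory_gb(params):,.0f} GB -> at least {min_gpus(params)} x 80 GB GPUs")
```

This prints 2,800 GB (35 GPUs) for GPT-3 and 16,000 GB (200 GPUs) for the 1-trillion-parameter model, before a single activation is stored.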

Model Parallelism is a paradigm where the model itself is spread across multiple GPUs, simply because it is too big to fit on a single one. You break the network down into smaller pieces, place each piece on a different GPU, and connect the GPUs so they can exchange the inputs, outputs and gradients. You can combine this with Data Parallelism by replicating the whole setup so that it runs on parallel batches of data. These are the scaling strategies used to train models like GPT-3.
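
Here is a minimal sketch of that idea, assuming a toy network of scalar "dense + ReLU" layers and treating devices as plain labels such as `"gpu:0"`; in PyTorch you would typically move each stage to its own device with `.to(device)` and move the activations between them. `make_layer`, `partition` and `forward` are hypothetical helpers.

```python
# Minimal sketch of model parallelism: the layers of ONE network are split into contiguous
# stages, each stage lives on a different device, and only activations cross device boundaries.
# Assumptions: toy scalar layers; devices are just labels.

def make_layer(scale, shift):
    """A toy one-unit "dense + ReLU" layer."""
    return lambda x: max(0.0, scale * x + shift)

def partition(layers, num_devices):
    """Assign contiguous, roughly equal groups of layers (stages) to each device."""
    n = len(layers)
    bounds = [n * d // num_devices for d in range(num_devices + 1)]
    return [(f"gpu:{d}", layers[bounds[d]:bounds[d + 1]]) for d in range(num_devices)]

def forward(stages, x):
    """Run the stages in order; record the activation each device hands to the next one."""
    transfers = []
    for device, stage in stages:
        for layer in stage:            # these layers' weights live only on `device`
            x = layer(x)
        transfers.append((device, x))  # activation sent on to the next device
    return x, transfers

layers = [make_layer(1.0 + i / 10, -0.1) for i in range(8)]
stages = partition(layers, num_devices=4)  # 2 layers per device
out, transfers = forward(stages, 1.0)
print(transfers)
```

During training, the backward pass walks the same stages in reverse, sending the gradients of the activations from each device back to the previous one. With this naive schedule only one device is busy at a time, which is why pipeline approaches split each batch into micro-batches to keep all devices working.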

But have you seen the price of a GPU? I have seen estimates on the order of $12M just to train GPT-3. How do you scale beyond that? For example, Megatron-2 [4] and Google's Switch Transformer [5] reach the trillion-parameter scale, and they use different strategies to scale across large accelerator clusters while keeping the cost down.

I like what the PyTorch team did with the Fully Sharded Data Parallel (FSDP) module [6]. It is a simple API to efficiently manage GPU memory for models that are too large for a single GPU. Instead of keeping the whole model loaded on every GPU, each GPU only stores a shard of the parameters (and of the gradients and optimizer states). The full parameters of a block are gathered only when a computation is performed on it (forward or backward pass) and are freed again once the local computation is finished; the parameters can optionally be offloaded to the CPU in between.
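
The following is a minimal single-process simulation of that sharding idea, seen from one rank. It is only an illustration under stated assumptions (4 simulated ranks, toy element-wise blocks, memory counted as numbers of parameter values), not PyTorch's implementation: the real module is `torch.distributed.fsdp.FullyShardedDataParallel`, and it also re-gathers the parameters for the backward pass. `shard`, `Rank` and `fsdp_forward` are hypothetical.

```python
# Minimal sketch of the FSDP idea: each rank permanently keeps only 1/WORLD_SIZE of every
# block's parameters, gathers a block's full weights right before using them, then frees them.
WORLD_SIZE = 4  # number of simulated ranks (GPUs)

def shard(flat, world_size=WORLD_SIZE):
    """Split a block's flat parameter list into contiguous per-rank shards."""
    n = len(flat)
    b = [n * r // world_size for r in range(world_size + 1)]
    return [flat[b[r]:b[r + 1]] for r in range(world_size)]

class Rank:
    """Counts how many parameter values one simulated GPU holds at a time."""

    def __init__(self):
        self.resident = 0
        self.peak = 0

    def alloc(self, n):
        self.resident += n
        self.peak = max(self.peak, self.resident)

    def free(self, n):
        self.resident -= n

def fsdp_forward(blocks, x, rank_id, rank):
    """Forward pass over all blocks as seen by rank `rank_id`."""
    shards = [shard(w) for w in blocks]
    for s in shards:                                      # persistent memory: own shards only
        rank.alloc(len(s[rank_id]))
    for s in shards:
        full = [v for part in s for v in part]            # all-gather this block's weights
        gathered = len(full) - len(s[rank_id])
        rank.alloc(gathered)
        x = [max(0.0, w * xi) for w, xi in zip(full, x)]  # the block's computation
        rank.free(gathered)                               # drop the other ranks' shards again
    return x

blocks = [[1.0 + 0.01 * (i + j) for j in range(8)] for i in range(6)]  # 6 blocks x 8 params
rank0 = Rank()
y = fsdp_forward(blocks, [1.0] * 8, rank_id=0, rank=rank0)
print(rank0.peak, "of", sum(len(w) for w in blocks), "parameter values held at peak")
```

Here rank 0 never holds more than 18 of the 48 parameter values (its own 12 plus one gathered block), yet it computes exactly the same output as an unsharded model.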

Here is the experiment where Meta and AWS partnered to train a 1-trillion-parameter GPT-3-style model on a cluster of 512 GPUs: “Training a 1 Trillion Parameter Model With PyTorch Fully Sharded Data Parallel on AWS”. They estimated that it would take 3 years to actually complete the full training!

## Why is a TPU faster than a GPU?

Why are TPUs (Tensor Processing Units) faster than GPUs? Well, depending on your use case, they may not be! TPUs are only effective for large Deep Learning models with long training times (weeks or months) whose computations are almost ONLY matrix multiplications (matrix multiplications are highly parallelizable). For example, you may not want to use LSTM layers on a TPU, since an LSTM processes a sequence one time step after the other (well, actually, since the advent of Transformers, you may not want to be using LSTMs, period!).
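
The data dependency that makes recurrent layers a poor fit is easy to see in code. This is a simplified sketch: the "RNN" is a scalar recurrence `h_t = tanh(w * h_{t-1} + u * x_t)` rather than a full LSTM with gates, and the weights `w` and `u` are arbitrary assumptions.

```python
import math

def rnn(xs, w=0.5, u=1.0):
    """Toy recurrence h_t = tanh(w * h_{t-1} + u * x_t) over a sequence of scalars."""
    h, hs = 0.0, []
    for x in xs:
        # Step t needs h from step t-1, so the steps cannot run at the same time.
        h = math.tanh(w * h + u * x)
        hs.append(h)
    return hs

def positionwise(xs, u=1.0):
    """Toy position-wise layer: each output only depends on its own input."""
    # No dependency between positions: all of them can be computed in parallel,
    # e.g. as one matrix multiplication over the whole sequence.
    return [math.tanh(u * x) for x in xs]

xs = [1.0, 2.0, 3.0]
print(rnn(xs))
print(positionwise(xs))
```

Reversing the input simply reverses the output of `positionwise`, since positions never interact, whereas the `rnn` outputs change because each step carries state from the previous one.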

There are a lot of downsides to using TPUs. TensorFlow is the main framework that can run on TPUs. PyTorch can run on them (through the PyTorch/XLA package) since TPU v3, but this support is still considered experimental and may not be as stable or feature-complete as TensorFlow on TPUs. Even with TensorFlow, you cannot use custom operations. TPUs are also expensive, although they are now cheaper than GPUs! Renting a TPU v4 machine costs at least $12.88/hour [7], whereas 4 NVIDIA A100 GPUs cost $15.72/hour [8].
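
Using the two prices quoted above, a quick calculation shows what a week of continuous training would cost. This assumes the machines run around the clock at those list prices; note that a TPU v4 machine and 4 A100 GPUs do not deliver the same amount of compute, so this compares prices, not price per performance.

```python
TPU_V4_PER_HOUR = 12.88     # minimum price of a TPU v4 machine, USD/hour [7]
FOUR_A100_PER_HOUR = 15.72  # price of 4 NVIDIA A100 GPUs, USD/hour [8]

def rental_cost(hourly_rate, days):
    """Cost in USD of renting a machine around the clock for `days` days."""
    return round(hourly_rate * 24 * days, 2)

print("TPU v4, 1 week:", rental_cost(TPU_V4_PER_HOUR, 7))      # 2163.84
print("4 x A100, 1 week:", rental_cost(FOUR_A100_PER_HOUR, 7))  # 2640.96
print("per A100 GPU-hour:", round(FOUR_A100_PER_HOUR / 4, 2))   # 3.93
```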

But TPUs are much faster [9]! For example, this blog post shows that they can be up to 5 times faster than GPUs: “When to use CPUs vs GPUs vs TPUs in a Kaggle Competition?”. A CPU processes instructions on scalar data mostly one after the other, with limited parallelism. A GPU is very good at handling vector data and can fully parallelize the computation of a dot product between two vectors. Since a matrix multiplication can be expressed as a set of independent vector dot products, a GPU is much faster than a CPU at computing matrix multiplications.
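
The decomposition of a matrix multiplication into independent dot products can be sketched in plain Python. This only illustrates the independence, not the speed: Python threads share one interpreter lock, so `matmul_parallel` is not faster here, whereas a GPU runs thousands of such dot products at the same time. Both function names are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def matmul(A, B):
    """C[i][j] = dot(row i of A, column j of B), computed one after the other."""
    cols = list(zip(*B))
    return [[dot(row, col) for col in cols] for row in A]

def matmul_parallel(A, B, workers=4):
    """Same result, but every dot product is submitted as an independent task."""
    cols = list(zip(*B))
    pairs = [(i, j) for i in range(len(A)) for j in range(len(cols))]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        flat = list(pool.map(lambda p: dot(A[p[0]], cols[p[1]]), pairs))  # map keeps order
    m = len(cols)
    return [flat[i * m:(i + 1) * m] for i in range(len(A))]

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(matmul(A, B))  # [[19, 22], [43, 50]]
```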

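A TPU uses a Matrix Multiply Unit (MXU) built as a systolic array. Unlike a GPU's general-purpose cores, the MXU passes values directly from one multiply-accumulate cell to the next, so each vector loaded from memory is reused across many dot products of a matrix multiplication instead of being read again for each one. This lets a TPU parallelize matrix multiplications much more efficiently than a GPU. More recent GPUs also include matrix multiply-accumulate units (NVIDIA's Tensor Cores), but to a lesser extent than TPUs.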
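
To make the data reuse concrete, here is a cycle-by-cycle simulation of a small systolic array. It is a simplified sketch: it uses the output-stationary layout (each cell keeps the running sum of one output element) because it is the easiest to simulate, whereas in the TPU the weights are preloaded into the cells and the activations stream through them. `systolic_matmul` is a hypothetical name.

```python
def systolic_matmul(A, B):
    """Multiply A (n x k) by B (k x m) on a simulated n x m output-stationary systolic array."""
    n, k, m = len(A), len(B), len(B[0])
    acc = [[0] * m for _ in range(n)]       # each cell's running sum = one element of C
    a_reg = [[None] * m for _ in range(n)]  # value each cell passes to its right neighbour
    b_reg = [[None] * m for _ in range(n)]  # value each cell passes to the cell below
    for t in range(n + m + k - 2):          # clock cycles needed to drain the array
        new_a = [[None] * m for _ in range(n)]
        new_b = [[None] * m for _ in range(n)]
        for i in range(n):
            for j in range(m):
                if j == 0:   # left edge: row i of A enters skewed by i cycles
                    s = t - i
                    a = A[i][s] if 0 <= s < k else None
                else:        # otherwise reuse the value the left neighbour held last cycle
                    a = a_reg[i][j - 1]
                if i == 0:   # top edge: column j of B enters skewed by j cycles
                    s = t - j
                    b = B[s][j] if 0 <= s < k else None
                else:        # otherwise reuse the value the cell above held last cycle
                    b = b_reg[i - 1][j]
                if a is not None and b is not None:
                    acc[i][j] += a * b      # one multiply-accumulate per cell per cycle
                new_a[i][j], new_b[i][j] = a, b
        a_reg, b_reg = new_a, new_b         # all cells update their registers at once
    return acc

A = [[1, 2, 3], [4, 5, 6]]
B = [[7, 8], [9, 10], [11, 12]]
print(systolic_matmul(A, B))  # [[58, 64], [139, 154]]
```

Each element of `A` and `B` is read from memory only once, at the edge of the array, and is then reused by every cell it flows through (m cells for an element of `A`, n cells for an element of `B`), with up to n × m multiply-accumulates happening in the same cycle.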

Only deep learning models can really exploit the parallelism of TPUs, as most other ML models (Random Forest, GBM, KNN, …) do not rely on matrix multiplications as their underlying algorithmic implementation.

You can read about the different TPU architectures here: “System Architecture”. Google provides a nice deep dive into TPUs: “An in-depth look at Google’s first Tensor Processing Unit (TPU)”. You can also read the original TPU paper: “In-Datacenter Performance Analysis of a Tensor Processing Unit”.

References:

1. PyTorch’s centralized synchronous approach: https://pytorch.org/docs/stable/notes/ddp.html
2. PyTorch DistributedDataParallel module: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
3. TensorFlow MultiWorkerMirroredStrategy module: https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy
4. *Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM* by Deepak Narayanan et al.: https://arxiv.org/pdf/2104.04473.pdf
5. *Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity* by William Fedus et al.: https://arxiv.org/pdf/2101.03961.pdf
6. PyTorch Fully Sharded Data Parallel module: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
7. TPU pricing: https://cloud.google.com/tpu/pricing
8. GPU pricing: https://cloud.google.com/compute/gpus-pricing
9. *Benchmarking TPU, GPU, and CPU Platforms for Deep Learning* by Yu Wang et al.: https://arxiv.org/pdf/1907.10701.pdf
