Have We Been Thinking About Batch Sizes Wrong This Entire Time?

BatchNorm was introduced in a 2015 paper by Sergey Ioffe and Christian Szegedy, both Google researchers at the time.

They introduced a technique for normalizing activations in each neural network layer to stabilize the training process and improve a network’s ability to generalize to new data.

Since its introduction, BatchNorm has become a widely used technique in deep learning. It has been included in many state-of-the-art models across various tasks, such as image classification, object detection, and machine translation.

It’s now considered a fundamental deep-learning technique in nearly all CNN architectures.

Why Do You Need BatchNorm?

A major challenge when training deep neural networks is the distribution of activation values.

These values can vary wildly as data flows through your network, especially in deeper layers, where activation distributions can drift far from the distribution of the input data. These shifts in distribution can cause your gradients to vanish or explode during training, making it difficult for your network to converge.

Let me explain using an analogy where a building represents your network, the foundation of the building represents your activations, and the labourers represent your neurons.

Imagine you’re at a construction site where labourers are laying the foundation of a building.

I’m sure you’d agree that the foundation is the most critical component for the stability of any building. It must be level. Now suppose you observe these labourers working independently, each with their own tools and materials, focused only on their small patch of foundation without any consideration of the other labourers.

What do you think would happen?

I’d wager the foundation ends up uneven. You might notice that some parts of the foundation are higher or lower than others. Or it might not be noticeable until the higher floors are constructed. By that time, the entire building will be unstable and at risk of toppling over.

If the labourers (the neurons) aren’t in sync, the foundation (the activations) will be uneven, and the building (the network) will be unstable.

BatchNorm is like a foreman on the construction site who regulates the output of the labourers. By keeping the labourers in sync and the foundation level (that is, keeping the activations at a consistent mean and variance), the foreman ensures the building (your network) won’t become a modern-day Leaning Tower of Pisa.

If activations are level, you end up with a more stable training process, faster convergence, better overall performance, and less risk of overfitting.

How Does BatchNorm Work?

BatchNorm works by normalizing the activations in each layer to a consistent mean and variance.

All you do is subtract the mean and divide by the standard deviation of the activations for each mini-batch. If you’re working with image data, you will do this separately for each activation channel. For example, each filter would have its own mean and standard deviation in a convolutional layer.

There are also two parameters learned during backpropagation, scale and shift, that let the network adjust the normalized activations to whatever mean and variance actually works best for each layer.

The network can’t learn very well when the activations in a layer have too small a variance. This is where the scale parameter comes in. It’s a multiplicative factor that stretches or shrinks the normalized activations as needed. The network can also struggle when a strict zero mean isn’t what a layer needs. That’s where the shift parameter comes in. It’s a value added to the normalized activations that moves their mean away from zero when that helps.
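To make this concrete, here’s a minimal NumPy sketch of the forward pass for a fully connected layer (the function name, toy data, and epsilon value are my own; for image data you’d compute the mean and variance per channel instead):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: activations of shape (batch_size, num_features)
    # gamma, beta: learned scale and shift, each of shape (num_features,)
    mu = x.mean(axis=0)                     # per-feature mean over the mini-batch
    var = x.var(axis=0)                     # per-feature variance over the mini-batch
    x_hat = (x - mu) / np.sqrt(var + eps)   # normalize to zero mean, unit variance
    return gamma * x_hat + beta             # learned scale and shift

# Toy usage: a mini-batch of 4 samples with 3 features each
x = np.random.randn(4, 3) * 10 + 5
out = batchnorm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
print(out.mean(axis=0), out.var(axis=0))    # roughly 0 mean and unit variance
```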

Where in a network BatchNorm should reside is a hotly debated topic.

The original paper (and Andrew Ng) say you should apply BatchNorm to the outputs of a layer before they’re passed to an activation function. That’s because you get the greatest benefit from activation functions (like ReLU or tanh) when the values passed to them fall in a small range centered around 0. This helps stabilize the training process and improves the network’s ability to generalize.

You can apply BatchNorm to any type of layer - including fully connected, convolutional, and recurrent layers.

The details will vary depending on the type of layer and the architecture of your network. For example, in a convolutional neural network (CNN), you might apply BatchNorm to each convolutional layer’s outputs before passing them through a non-linear activation function. In a fully connected network, you’d apply it to the outputs of each fully connected layer.
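In PyTorch terms, that placement looks something like this sketch (it assumes a single-channel 28x28 input; bias=False is a common choice because BatchNorm’s shift parameter makes the preceding layer’s bias redundant):

```python
import torch.nn as nn

# Conv -> BatchNorm -> activation, then FC -> BatchNorm -> activation
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),    # normalize conv outputs before the non-linearity
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 28 * 28, 64, bias=False),
    nn.BatchNorm1d(64),    # normalize FC outputs before the non-linearity
    nn.ReLU(),
    nn.Linear(64, 10),
)
```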

There’s no hard rule on where it should be applied, though it’s commonly applied after linear or convolutional layers.

You also don’t need to apply it to every single layer. You can apply it only to specific layers, like the last fully connected layers or the bottleneck layers of the network.

A good rule of thumb is to use it in layers where the activations have a large variance, or when your network isn’t converging as expected.

The Power of BatchNorm…

There’s a reason BatchNorm is in nearly all state-of-the-art CNN architectures.

  1. It improves stability. It helps prevent your model from oscillating or diverging during training.

  2. It improves generalization (which is a nice side effect). Normalizing activations within a network prevents the model from relying too much on any feature or pattern in the data, making the model more robust.

  3. It improves training speed. By keeping activation distributions from changing drastically during training, it reduces something called internal covariate shift, and your model converges faster.

  4. It lets you use higher learning rates and makes the training process less sensitive to weight initialization and hyperparameter tuning.

  5. It reduces the effect of large outlier gradients during gradient descent and diminishes the gradients’ reliance on the parameters’ scale or their underlying values.

  6. It reduces the need for dropout. Because the values going into an activation function are normalized to a mean of zero, BatchNorm gives you some of the regularization benefits of dropout. You can either use a lower dropout rate or skip the dropout layer altogether.

If you want to go deep on these points and get into the theory of how BatchNorm helps optimization, check out this 2018 paper presented at NeurIPS.

The Thing About BatchNorm

BatchNorm works, but it’s still not fully understood why it works.

Yann LeCun once referred to BatchNorm as a “necessary evil” in the design of convolutional neural networks (CNNs) due to its potential drawbacks and limitations.

A paper published in May 2021, titled “Rethinking ‘Batch’ in BatchNorm” by Yuxin Wu and Justin Johnson from Facebook AI Research (FAIR), delved into the inner workings of BatchNorm. The authors point out that BatchNorm operates on batches of data rather than individual samples and ask the question (I paraphrase) …

What even is a batch?

There are many different choices for what the “batch” to normalize over can be.

BatchNorm can compute statistics from the entire dataset, from mini-batches during SGD, or from a subset of mini-batches. There’s an added challenge when inputs come from different domains, because statistics can be computed over all domains combined, over each domain separately, or over a mixture of samples from each domain.

The result is a normalization process whose output depends on how samples are grouped into batches. In the paper, the authors highlight how this can hurt model performance if it isn’t carefully considered:

• BatchNorm mixes information across the samples in a batch to compute its normalization statistics; most other operators process each sample in a batch independently.

• Not considering different choices for batch construction can negatively impact model performance.

• When test time comes around, there is usually no concept of a mini-batch anymore. Features instead get normalized by population statistics computed over the entire training set. That is, the whole population acts as the “batch.”
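You can see this train/test split directly in PyTorch. In this small sketch (my own toy example, not from the paper), the same layer normalizes with mini-batch statistics in training mode and with its accumulated running statistics in eval mode:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)
x = torch.randn(32, 8)

bn.train()
out_train = bn(x)   # normalized with this mini-batch's mean and variance

bn.eval()
out_eval = bn(x)    # normalized with the running (population) estimates
print(bn.running_mean, bn.running_var)
```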

The authors also discuss the perils of using an Exponential Moving Average (EMA) to estimate the population statistics. EMA may lead to a suboptimal estimate of population statistics for the following reasons:

• Statistics converge slowly when the EMA momentum is large.

• When the momentum is small, EMA statistics become dominated by a small number of recent mini-batches and don’t represent the whole population.

• In the early stages of training, EMA cannot accurately represent either the mini-batch statistics or the population statistics.

As a result, using EMA as the “population statistics” may harm the model’s accuracy.
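For context, here’s roughly what that EMA update looks like (a toy sketch using the paper’s convention, where the momentum is the weight on the old estimate; note that PyTorch’s momentum argument is defined the other way around):

```python
def ema_update(running, batch_stat, momentum=0.9):
    # running <- momentum * running + (1 - momentum) * batch_stat
    # Momentum close to 1: the estimate adapts slowly.
    # Momentum close to 0: the estimate is dominated by recent mini-batches.
    return momentum * running + (1.0 - momentum) * batch_stat

# Toy usage: track a running mean across a handful of mini-batches
running_mean = 0.0
for batch_mean in [5.2, 4.8, 5.1, 5.0]:
    running_mean = ema_update(running_mean, batch_mean)
print(running_mean)  # ~1.7, still nowhere near the true mean of ~5
```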

The authors also analyze the impact of these issues on model accuracy and suggest different alternatives to EMA to overcome these limitations, such as using moving statistics or a control variate method. You can read more about the relationship between BatchNorm and EMA here.

But they do discuss a solution…

Have You Met Precise BatchNorm?

The authors propose a solution to the limitations of traditional BatchNorm using a method called “Precise BatchNorm.”

Precise BatchNorm computes more accurate population statistics by using two steps:

  1. Applying a fixed model* to multiple mini-batches to collect per-batch statistics.

  2. Aggregating these per-batch statistics to form the population statistics.

*A fixed model is a trained model whose parameters are no longer changing. Using a fixed model ensures that the computed statistics accurately represent the entire population for that model state.
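Here’s a minimal PyTorch-style sketch of those two steps: run the frozen model over a number of mini-batches, let each BatchNorm layer collect fresh batch statistics, and aggregate them with equal weights. This is my own illustration of the idea, not the paper’s or SuperGradients’ actual implementation:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def recompute_bn_stats(model, loader, num_batches=100):
    # Step 1: run the fixed (frozen) model on several mini-batches so every
    #         BatchNorm layer sees fresh batch statistics.
    # Step 2: aggregate those per-batch statistics with equal weights
    #         (momentum=None switches PyTorch to a cumulative average).
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momentum = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None             # equal weighting of every mini-batch

    model.train()                      # BN collects batch statistics in train mode
    for i, (images, _) in enumerate(loader):
        if i >= num_batches:
            break
        model(images)                  # weights stay fixed; only BN stats update

    for bn, m in zip(bn_layers, saved_momentum):
        bn.momentum = m                # restore the original EMA behaviour
    model.eval()
```

Setting momentum to None turns PyTorch’s running estimates into a plain cumulative average, which is exactly the “all samples equally weighted” property described below.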

Precise population statistics maintain two important properties of true population statistics that EMA-based estimates lack:

  1. Statistics are computed entirely from a fixed model state, unlike EMA, which mixes in historical states of the model.

  2. All samples are equally weighted.

Another limitation of traditional BatchNorm is its imprecision when dealing with small batch sizes, leading to inaccurate population statistics.

Precise BatchNorm aims to solve this problem by using larger batch sizes, ensuring that the computed statistics are more accurate. The authors say a batch size of B > 100 appears to be sufficient in their experiments. They also note that EMA is not a good choice when estimating population statistics because it relies on the same mini-batch size used in Stochastic Gradient Descent (SGD) during training.

The authors emphasize that recomputing population statistics, as Precise BatchNorm does, is actually how BatchNorm was originally formulated; it just hasn’t been widely adopted in the field. For more detail beyond the original paper, check out this slide deck by one of the authors, Yuxin Wu.

The team at Deci AI has integrated Precise BatchNorm in the SuperGradients training library; let’s check it out in action.

In this toy example, using the FashionMNIST dataset, I set the batch size to a small value to mimic the scenario where I can’t fit a large batch on a GPU (not really the case here, since these are small images). I then experimented with Precise BatchNorm batch sizes of varying orders of magnitude to see the impact on training.
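For reference, turning it on in SuperGradients is roughly a matter of flipping a couple of training parameters. The snippet below is from memory rather than from the notebook, so treat the precise_bn keys as assumptions and double-check them against the library’s docs:

```python
from super_gradients import Trainer
from super_gradients.training import models

# NOTE: "precise_bn" and "precise_bn_batch_size" are the training-param keys
# as I remember them -- confirm against the SuperGradients docs or the notebook.
trainer = Trainer(experiment_name="fashionmnist_precise_bn")
model = models.get("resnet18", num_classes=10)

training_params = {
    "max_epochs": 5,
    "initial_lr": 0.1,
    "loss": "cross_entropy",
    "precise_bn": True,            # recompute BN statistics with the fixed model
    "precise_bn_batch_size": 200,  # effective batch size for that recomputation
}

# trainer.train(model=model, training_params=training_params,
#               train_loader=train_loader, valid_loader=valid_loader)
```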

Check out the notebook here

I think you might be interested in this @salmankhaliq22 , @sGx_tweets , @chris , @lu.riera , @ZubiaMansoor, @Raouf , @mjcullan
