Tips for training your neural networks

I asked a dozen Deci deep learning legends to share their favorite neural network training tips…

These 5 came up most frequently:

  1. Exponential moving average
  2. Weight averaging
  3. Batch accumulation
  4. Precise batchnorm
  5. Zero weight decay on batch norm and bias

Let me break these down for you 👇

1) Exponential Moving Average

When you’re training a neural network chances are you’re using mini-batches.

Nothing wrong with that; it just introduces noise, because gradient descent updates the model parameters from less accurate, per-batch gradients. On one hand, that's nice: noisy gradients can sometimes help optimization and lead to a better local optimum than training on the entire dataset would. On the other hand, the noise might cause convergence to a spurious local minimum.

Luckily you have the EMA method at your disposal.

Here’s how it works:

  • Let W_{m} be the current set of weights after performing an optimization step
  • Copy those weights before the next optimization step
  • Take a weighted average of the weights you just copied and the EMA weights from the previous step
  • This weighted average becomes the EMA weights at the current step

Here’s what it looks like more formally:

W^{ema}_{n} = (1 - a(n))*W_{n} + a(n)*W^{ema}_{n-1}

EMA makes your models more stable, improves convergence, and helps your network find a better solution.
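The update formula above can be sketched in a few lines of plain Python. For illustration, the weights are dicts of floats and `alpha` is a fixed constant standing in for a(n); in a real framework you'd apply the same blend to every tensor of the model. The function and variable names here are my own, not from any particular library:

```python
def ema_update(ema_weights, current_weights, alpha):
    """Blend the current weights into the EMA copy:
    W_ema_n = (1 - alpha) * W_n + alpha * W_ema_{n-1}."""
    return {
        name: (1 - alpha) * w + alpha * ema_weights[name]
        for name, w in current_weights.items()
    }

weights = {"w1": 1.0, "w2": -2.0}
ema = dict(weights)  # the shadow (EMA) copy starts equal to the weights
for step in range(3):
    # pretend optimizer step: nudge every weight a little
    weights = {k: v + 0.1 for k, v in weights.items()}
    ema = ema_update(ema, weights, alpha=0.9)
```

Note that the EMA copy lags behind the raw weights, smoothing out step-to-step noise; at evaluation time you'd run inference with the EMA weights.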

2) Weight Averaging

Everyone likes a free boost in model accuracy.

And that’s what weight averaging gives you. It’s a post-training method that takes the best model weights across training epochs and averages them into a single model. By averaging the weights of the N best checkpoints, we’re effectively building an ensemble of N models. It’s not exactly the same as keeping N models and averaging their predictions - which comes at the price of running inference N times - but it can help squeeze out some extra accuracy.

It does this by overcoming the optimizer’s tendency to alternate between adjacent local minima in the later stages of training. It also has the added benefit of reducing bias.

This trick doesn’t affect training at all; it just keeps a few additional checkpoints on disk, but it can give you a nice boost in performance and stability.
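A minimal sketch of the averaging step, assuming checkpoints are dicts of floats (real checkpoints hold tensors, but the arithmetic is the same); `average_checkpoints` is an illustrative name, not a library function:

```python
def average_checkpoints(checkpoints):
    """Average the weights of N saved checkpoints into a single model."""
    n = len(checkpoints)
    return {
        name: sum(ckpt[name] for ckpt in checkpoints) / n
        for name in checkpoints[0]
    }

# e.g. the 3 best checkpoints by validation accuracy
best_ckpts = [{"w": 1.0, "b": 0.2}, {"w": 2.0, "b": 0.4}, {"w": 3.0, "b": 0.6}]
averaged = average_checkpoints(best_ckpts)
```

In PyTorch, `torch.optim.swa_utils.AveragedModel` implements a running version of this idea (stochastic weight averaging) if you'd rather not roll your own.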

3) Batch Accumulation

Have you ever tried to amend a recipe to fit the ingredients, cookware, and oven you have on hand?

It’s not an easy thing to do!

Most ‘off the shelf’ models come with a suggested training recipe, which usually assumes a powerful GPU. I don’t know about you, but I don’t have powerful GPUs just lying around, and I’m not about to pay for a premium Colab subscription. If you simply reduce your batch size so training fits on your hardware, you’ll have to tune other parameters as well.

Which means you won’t always get the same training results.

There’s got to be a way to train a model on the hardware you actually have. That’s where batch accumulation comes in. Here’s how it works…

  1. Run the forward and backward pass on several consecutive mini-batches
  2. Accumulate the gradients along the way
  3. Apply the optimizer update only once every few batches.
  4. Then you draw the rest of the owl. Just kidding, you can see how it works with code here.
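Here's a toy sketch of those steps in plain Python, using a made-up squared loss so the gradient is easy to compute by hand; the function names and the loss are assumptions for illustration, not from any framework:

```python
def grad_on_batch(w, batch):
    # gradient of the mean loss 0.5 * (w - x)^2 over one mini-batch
    return sum(w - x for x in batch) / len(batch)

def accumulated_step(w, batches, lr):
    # steps 1-2: run several consecutive mini-batches, accumulating gradients
    grad_sum = 0.0
    for batch in batches:
        grad_sum += grad_on_batch(w, batch)
    # step 3: a single update, as if one batch len(batches) times larger was used
    avg_grad = grad_sum / len(batches)
    return w - lr * avg_grad

# two small batches behave like one batch of four samples
w = accumulated_step(0.0, batches=[[1.0, 2.0], [3.0, 4.0]], lr=1.0)
```

In PyTorch the same effect is achieved by calling `loss.backward()` on each mini-batch (gradients accumulate in `.grad` by default) and only calling `optimizer.step()` and `optimizer.zero_grad()` every few batches.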

4) Precise BatchNorm

BatchNorm is a wonderful invention.

Ever since it hit the scene in 2015, it’s been making models less sensitive to learning rates and choice of initialization. It’s also helped speed up model convergence. Hell, it’s even helped wage war against overfitting.

It’s no wonder it’s been used in nearly every state-of-the-art CNN in recent years.

BatchNorm does, however, have its problems…

Batch normalization in the mind of many people, including me, is a necessary evil. In the sense that nobody likes it, but it kind of works, so everybody uses it, but everybody is trying to replace it with something else because everybody hates it - Yann LeCun

[It’s ] a very common source of bugs – CS231n 2019 Lecture7

Why does BatchNorm catch such flack?

Well, BatchNorm layers are meant to normalize the data based on the dataset’s distribution. Ideally, you’d estimate that distribution over the entire dataset, but that isn’t feasible. So BatchNorm layers estimate the statistics from the current mini-batch throughout training.

But a 2021 paper by Facebook AI Research titled Rethinking “Batch” in BatchNorm showed that these mini-batch based statistics are sub-optimal.

The researchers propose estimating the data statistics parameters (the mean and standard deviation variables) across several mini-batches, while keeping the trainable parameters fixed.

This method, titled Precise BatchNorm, helps improve both the stability and performance of a model.
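The core of the re-estimation step can be sketched like this, treating each mini-batch as a list of floats; this is a simplified illustration of the idea (averaging per-batch statistics with the trainable parameters frozen), with names of my own choosing:

```python
def precise_bn_stats(batches):
    """Re-estimate BatchNorm's running mean/variance as a plain average
    over several mini-batches (trainable scale/shift stay frozen)."""
    means = [sum(b) / len(b) for b in batches]
    variances = []
    for b, m in zip(batches, means):
        variances.append(sum((x - m) ** 2 for x in b) / len(b))
    # average the per-batch statistics instead of an exponential running estimate
    return sum(means) / len(means), sum(variances) / len(variances)

mean, var = precise_bn_stats([[0.0, 2.0], [4.0, 6.0]])
```

In PyTorch, `torch.optim.swa_utils.update_bn` performs a comparable recomputation of BatchNorm running statistics by replaying data through the model.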

5) Zero-weight Decay on BatchNorm and Bias

Most computer vision tasks have BatchNorm layers and biases along with linear or convolution layers.

This tends to work well because you’ll have more parameters in your model. More parameters means more ways to capture interactions between parts of your network. More parameters, however, also means more opportunities to overfit.

This is where a regularization technique called weight decay comes into play.

Weight decay (aka L_{2} regularization) helps reduce the complexity of a model and prevent overfitting by adding a penalty term to the loss:

L_{new}\left(w\right) = L_{original}\left(w\right) + \lambda{w^{T}w}

Where \lambda is a value that determines the strength of the penalty and encourages smaller weights.

This method essentially pulls the weights towards 0. That’s great for convolutional and linear layer weights. Not so much for BatchNorm layers.

BatchNorm’s trainable parameters scale and shift the normalized input to a layer. Forcing these values towards zero would distort the distribution and hurt results.
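In practice, this trick means splitting parameters into two optimizer groups. Here's a sketch: matching parameters by name is a simplification I'm assuming for illustration (real code usually checks the module type, e.g. `isinstance(m, nn.BatchNorm2d)`), and `split_decay_groups` is a hypothetical helper:

```python
def split_decay_groups(named_params, weight_decay):
    """Build optimizer parameter groups: BatchNorm weights and all biases
    get weight_decay=0, everything else gets the regular decay."""
    decay, no_decay = [], []
    for name, param in named_params:
        # name-matching heuristic; real code should inspect module types
        if name.endswith(".bias") or ".bn" in name or name.startswith("bn"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

named = [("conv1.weight", "p0"), ("conv1.bias", "p1"),
         ("bn1.weight", "p2"), ("bn1.bias", "p3")]
groups = split_decay_groups(named, weight_decay=1e-4)
```

The resulting list of group dicts is the format PyTorch optimizers accept directly, e.g. `torch.optim.SGD(groups, lr=0.1)`.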

This blog post sheds some more light on the peculiar effects of BatchNorm.

Now I know you’re probably asking yourself…

“How can I implement these tricks?”

You could try to find code online that does all this for you, search StackOverflow, copy-paste from others, or…

You could use SuperGradients.

SuperGradients contains all the tricks listed above and more, so you can use it immediately to train deep learning models to their best possible accuracy and beyond.

Super what?


SuperGradients is a neural network training package that is constantly updated with the best recipes, the latest cutting-edge models, and the newest training techniques. SuperGradients makes it easier to keep your production models up-to-date and lets you apply the best research and practices to your computer vision tasks.

Check it out and smash a star:

What other tips would you add? Curious if @lu.riera, @chris, @LisaLi, @yashikajain201, @sGx_tweets, @kbaheti, @EKhvedchenya, or @iseeag have any thoughts here.