Sharpness-Aware Minimization for Efficiently Improving Generalization

Written by Michael (Mike) Erlihson, Ph.D.

This review is part of a series of Machine & Deep Learning paper reviews, originally published in Hebrew under the name #DeepNightLearners, that aim to make the material accessible in plain language.

Good night, friends. We meet again in our #DeepNightLearners section for a review of a Deep Learning article. This time I've chosen to review the article Sharpness-Aware Minimization for Efficiently Improving Generalization.

Reviewer Corner:

Reading recommendation from Mike: A must if you're interested in the "behind the scenes" of the neural network training process
Clarity of writing: Very High
Math and Deep Learning knowledge level: Good knowledge of optimization methods for multi-parameter problems
Practical applications: Improving neural network generalization by replacing the common loss-minimization algorithm with SAM

Article Details:

Article link: available for download here
Code link: available here
Published on: 03/10/2021, on arXiv
Presented at: ICLR 2021

Article Domains:

  • Neural network training optimization method research

Mathematical Tools, Concepts, and Notation:

The Article in Essence:

The article reformulates the optimization problem that is solved during neural network training. Instead of looking for a weight vector that merely minimizes the loss function over the given training set, the article suggests solving a different optimization problem: finding a weight vector whose entire neighborhood yields low loss values. Rather than using gradient descent to drive the weights toward a (global, absolute) minimum of the loss, their algorithm adjusts the network so that the loss function stays small in the whole region around that point as well.

Furthermore, the article rigorously proves that solving this proposed optimization problem, called Sharpness-Aware Minimization (SAM), improves the generalization capabilities of the trained model.

Basic Idea

As you probably know, most modern neural networks are significantly over-parameterized. As a result, optimizing the network weights until the loss reaches an absolute minimum may lead to over-fitting and hence to a loss of generalization capability. The main reason for this is the complex, non-convex geometric structure of the loss surface. A classic example is a loss function whose minima are very "sharp": the loss values are high even slightly away from the minimum, and low only at the minimum point itself. Such sharp minima can arise from noisy data and lead to over-fitting. The article offers a solution: it reformulates the optimization problem so that it takes into account the loss values at the points surrounding a potential minimum, and not only at the minimum itself. In other words, the proposed reformulation explicitly considers the geometric properties of the loss surface around the minima.

Deep Dive

Many different methods attempt to improve the generalization capabilities of neural network models. They can be grouped into two families:

  1. Changes in the training process
    - Such as early-stopping, BatchNorm, stochastic depth, data augmentation, and many others.
  2. Optimizer modifications
    - Adam, RMSProp, adding momentum, etc.

Both families keep the same underlying optimization problem of minimizing the loss function; this article, in contrast, suggests replacing the formulation of the optimization problem itself (!!!).

Technical Details

The suggested loss function consists of two terms. The first is the maximal loss over a small neighborhood of the point w, whose radius ρ is a hyper-parameter; the second is a standard regularization term based on an Lp norm (L2 in practice), similar to proximal point optimization methods. Equivalently, the objective can be decomposed into the sum of three terms: the difference between the maximal loss value in the neighborhood of w and the loss at w itself (called "sharpness" in the article), the loss value at the point w, and the norm regularization term on the weight vector w.
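Written out (a sketch following the paper's spirit; L_S denotes the training loss, ρ the neighborhood radius, and λ the weight-decay coefficient):

  min_w  L_SAM(w) + λ·||w||_2²,    where    L_SAM(w) = max_{||ε||_p ≤ ρ}  L_S(w + ε)

  L_SAM(w) = [ max_{||ε||_p ≤ ρ} L_S(w + ε) − L_S(w) ]  +  L_S(w)

  (the bracketed term is the "sharpness" at w)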

Theoretical aspects

The article proves that for a given training set, the SAM loss at every point w is, with high probability, an upper bound on the population loss (in the article this is shown in a more general form, covering a broader family of regularization functions). Naturally, this holds under technical conditions on the distribution from which the dataset is sampled. It implies that solving the SAM problem leads to a better-generalizing model. The proof is non-trivial and relies on PAC-Bayesian generalization bounds.
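Schematically, the guarantee has roughly the following shape (my paraphrase of the bound's form, not its exact statement or conditions):

  L_D(w)  ≤  max_{||ε||_2 ≤ ρ} L_S(w + ε)  +  h(||w||_2² / ρ²)

where L_D is the population loss, L_S the training loss on the sample, and h an increasing function that plays the role of the regularization term.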

Solving the SAM problem

First, a first-order Taylor expansion is used to approximate the point around w that maximizes the loss. The resulting problem reduces to a classical dual-norm problem, which has a closed-form solution ε̂(w). After substituting ε̂(w) into the SAM expression, we obtain a standard minimization problem with cost function L(w + ε̂(w)), which can be solved as usual with gradient descent. Since ε̂(w) itself contains the gradient of the original loss L, the gradient of L(w + ε̂(w)) contains the Hessian matrix of L. Computing the Hessian when w has hundreds of millions of components is prohibitively heavy in both computation and memory. Luckily, the expression involves only a Hessian-vector product, which allows evaluating the gradient of L(w + ε̂(w)) without ever forming the Hessian. As a result, their algorithm can be run much like gradient descent using automatic-differentiation tools such as TensorFlow or PyTorch.
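To make this concrete, here is a minimal sketch of a single SAM update step in PyTorch (my own illustration, not the authors' official code; the name sam_step, the rho default, and the p = 2 choice are assumptions, and for simplicity the second-order Hessian term mentioned above is dropped, i.e., only the first-order approximation of the gradient is used):

  import torch

  def sam_step(model, loss_fn, x, y, base_optimizer, rho=0.05):
      """One SAM update: ascend to w + eps_hat, take the gradient there, step at w."""
      model.zero_grad()

      # 1) Gradient of the loss at the current weights w
      loss_fn(model(x), y).backward()

      # 2) Closed-form perturbation for p = 2: eps_hat = rho * g / ||g||
      grads = [p.grad for p in model.parameters() if p.grad is not None]
      grad_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
      eps = []
      with torch.no_grad():
          for p in model.parameters():
              if p.grad is None:
                  eps.append(None)
                  continue
              e = rho * p.grad / (grad_norm + 1e-12)
              p.add_(e)              # move to w + eps_hat
              eps.append(e)

      # 3) Gradient of the loss at the perturbed point w + eps_hat
      model.zero_grad()
      loss_fn(model(x), y).backward()

      # 4) Return to w and apply the base optimizer (e.g. SGD) with the new gradient
      with torch.no_grad():
          for p, e in zip(model.parameters(), eps):
              if e is not None:
                  p.sub_(e)
      base_optimizer.step()
      model.zero_grad()

Usage would be something like base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9), with sam_step(...) called once per mini-batch in place of the usual backward-and-step pair; note that each SAM step costs two forward-backward passes.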

Achievements

The article convincingly demonstrates that the proposed method outperforms previous training approaches (different augmentation types, other optimizers, etc.) on a wide range of datasets and neural network architectures. In each comparison, they simply replaced the original optimization procedure with SAM and compared performance on the test set. Furthermore, they evaluated SAM on datasets with noisy labels and examined the changes in the eigenvalues of the Hessian matrix at the solution of the SAM problem.

[Table in the original article] Results for SAM on state-of-the-art models on CIFAR-{10, 100} (WRN = WideResNet; AA = AutoAugment; SGD is the standard non-SAM procedure used to train these models).

Noisy Labels

SAM showed a significant improvement when training on datasets with noisy labels. This shouldn't come as a surprise, since the main strength of the algorithm is preventing convergence to sharp minima, and with classic optimization algorithms noisy labels can easily lead to exactly such minima.

Hessian matrix structure around the optima

To support their claims regarding SAM's ability to avoid sharp minima, the article examines the Hessian eigenvalues at the minima found by SAM and compares them to those at the minima found by other algorithms. Specifically, they look at the maximal eigenvalue and at the ratio between it and other relatively large eigenvalues. The sharper a minimum is, the larger the Hessian eigenvalues are, and so is the ratio between the largest eigenvalue and the next-largest ones. The article shows that with SAM, both metrics decrease significantly.
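For intuition about how such eigenvalues can even be measured for huge networks, here is a minimal sketch (my own illustration, not from the article) that estimates the largest Hessian eigenvalue with power iteration and Hessian-vector products, so the Hessian is never formed explicitly:

  import torch

  def top_hessian_eigenvalue(model, loss_fn, x, y, iters=20):
      """Estimate the largest eigenvalue of the loss Hessian via power iteration."""
      params = [p for p in model.parameters() if p.requires_grad]
      loss = loss_fn(model(x), y)
      # Keep the graph so the gradient can be differentiated again (for HVPs)
      grads = torch.autograd.grad(loss, params, create_graph=True)

      # Random unit starting vector, stored as one tensor per parameter
      v = [torch.randn_like(p) for p in params]
      norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
      v = [vi / norm for vi in v]

      eig = 0.0
      for _ in range(iters):
          # Hessian-vector product: Hv = d/dw (g · v), no explicit Hessian needed
          gv = sum((g * vi).sum() for g, vi in zip(grads, v))
          hv = torch.autograd.grad(gv, params, retain_graph=True)
          eig = sum((h * vi).sum() for h, vi in zip(hv, v)).item()  # Rayleigh quotient
          norm = torch.sqrt(sum((h ** 2).sum() for h in hv))
          v = [h / (norm + 1e-12) for h in hv]
      return eig

Power iteration converges to the dominant eigenvalue, so running such a routine at minima reached by SAM versus plain SGD would reproduce, in spirit, the comparison described above.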

Datasets

CIFAR10, CIFAR100, Flowers, Stanford_cars, Birdsnap, Food101, Oxford_IIIT_Pets, FGVC_Aircraft, Fashion-MNIST, and a few others.

Network Architectures

Wide-ResNet-28-10, Shake-Shake, EffNet, TBMSL-Net, GPipe, and a few others.

P.S.

This highly relevant article suggests an unconventional way to improve the generalization capabilities of neural networks. I believe this method has great potential to enter the standard neural network training toolkit. I was also impressed by the detailed comparison with other methods.

#deepnightlearners

This post was written by Michael (Mike) Erlihson, Ph.D.

Michael works as a principal data scientist at the cyber-security company Salt Security. He researches and works in the deep learning field, while lecturing and making scientific material more accessible to the general public.