Representation learning via invariant causal mechanisms
Written by Michael (Mike) Erlihson, Ph.D.
This review is part of a series of reviews in Machine & Deep Learning, originally published in Hebrew, aiming to make this material accessible in plain language under the name #DeepNightLearners.
Good night friends, today we are again in our section DeepNightLearners with a review of a deep learning article. Today I've chosen to review the article Representation Learning via Invariant Causal Mechanisms.
Reviewer Corner:
Reading recommendation From Mike: Recommended if you are into representation learning and have basic knowledge of causality theory
Clarity of writing: Medium plus
Math and Deep Learning knowledge level: Basic familiarity with representation learning and causality theory tools
Practical applications: Performance improvements for any noise-contrastive estimation (NCE) based representation learning method
Article Details:
Article link: available for download here
Code link: Not available
Published on: 15/10/2020 on arXiv
Presented at: ICLR 2021
Article Domains:
- Representation learning
- Causality theory
Mathematical Tools, Concepts, and Notations:
- Causal graph of a probabilistic model
- Noise-contrastive estimation (NCE)
- InfoNCE - Contrastive Predictive Coding
- Kullback-Leibler (KL) divergence - a measure of distance between distributions
- Task refinement
The Article in Essence:
The article suggests a method (called RELIC) to construct lower-dimensional data representations. The idea is a generalization of InfoNCE, and is executed by adding a regularization term to the loss function. The goal of this regularization term is to ensure that the similarity distribution between representations is invariant under the different augmentations applied to the examples. In the article, augmentations are also called style changes, and I will use both terms in this review. More on that later.
Let's first understand what this additional term contributes to the loss. All NCE-based methods bring similar samples closer together in the representation vector space. For images, this closeness is defined by semantic or content similarity. Augmentations, such as rotation and cropping, should not significantly change an image's vector representation. The suggested regularization term "forces" the representation to be invariant to non-semantic changes, which have no real effect on the distance (e.g., style changes). In other words, augmented image representations are forced to maintain the same distances from other image representations as the original image (before augmentation). This is a significant addition to the original NCE loss because it 'forces' the representations to be invariant to the image style and to represent only the image content, thus leading to representations that are more relevant and better correlated with downstream (content-related) tasks. This is the base assumption of the article.
Basic idea
The article's basic idea is built on three core assumptions, which model the image creation process as a causal graph.
Image creation process:
- An image is created from two latent variables: the content C and the style S.
- The variables S and C are independent (the content is independent of the style).
- Only the image content is relevant for the downstream tasks for which the representation is built. The image style is irrelevant for these tasks: changing it won't change the task's outcome Y_t.
For example, if we have a classification task with two classes (e.g., cats and dogs), the animal's body would be the content, while the background, lighting, camera lens, and so on would constitute the style.
Under these assumptions, the image content is a good representation for downstream tasks. As a result, the goal of representation learning is to estimate the image content. Recall that the content of an image X carries all the information relevant for the task inference, and it therefore must remain invariant to style changes.
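To make these assumptions concrete, one way to write them down (my own notation, not taken verbatim from the article) is:

X = g(C, S), \qquad C \perp S, \qquad p(Y_t \mid C, S) = p(Y_t \mid C)

where g is the unknown image-generating mechanism: intervening on the style S changes the image X, but by the last equality it leaves the task outcome Y_t untouched.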
A short explanation of the article's basic concepts
One of the basic concepts in the article is NCE-based representation learning. Let's refresh what it is about.
NCE methods: NCE is based on the assumption that a good data representation should be able to differentiate between similar and randomly selected examples. A famous example of this technique is negative sampling, which was used in word2vec. It is possible to prove that for InfoNCE, a specific form of the NCE loss, the smaller the loss, the larger the mutual information between the original sample and its low-dimensional representation. This indicates that less information is lost in the transformation from the original data to its representation. By the way, according to the article, other researchers argue that downstream task performance depends more on the encoder architecture and less on the mutual information. It is important to mention that the training occurs in the representation space, not in the original space; that is, the loss is calculated on the low-dimensional representations. The NCE loss takes a pair of similar samples and many random samples, and maximizes the ratio between the similarity of the similar pair and the sum of similarities between the pair and all the other random samples.
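To make the last point concrete, here is a minimal PyTorch sketch of an InfoNCE-style loss. This is my own illustration of the description above, not the authors' code; the function name and the temperature value are assumptions.

import torch
import torch.nn.functional as F

def info_nce_loss(anchors, positives, temperature=0.1):
    # anchors, positives: (N, D) batches of representations, where
    # positives[i] is a differently-augmented view of anchors[i];
    # the other N - 1 samples in the batch act as negatives.
    anchors = F.normalize(anchors, dim=1)
    positives = F.normalize(positives, dim=1)
    # Pairwise similarities between every anchor and every candidate.
    logits = anchors @ positives.t() / temperature  # shape (N, N)
    # The matching index holds the positive pair, so maximizing the ratio
    # of the positive similarity to the sum over all candidates is exactly
    # cross-entropy against the diagonal.
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(logits, targets)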
The article essence
To understand the article, we must first understand what task refinement means.
Task refinement: a rigorous definition of this term is taken from causality theory, but for the sake of simplicity, I will explain it with an example. A classification task Y_R between different breeds of dogs (or different breeds of cats) is a refinement of the cats-vs-dogs classification task Y_t. A (good enough) representation for Y_R contains enough information to also perform well on Y_t.
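In slightly more formal terms (my phrasing, not the article's exact definition), Y_R refines Y_t when the finer label determines the coarser one:

\exists\, h : \mathcal{Y}_R \to \mathcal{Y}_t \quad \text{such that} \quad Y_t = h(Y_R)

In the example above, h maps every dog breed to "dog" and every cat breed to "cat".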
Why is it so important, you ask?
The task of discriminating between different image contents, the way it is done in NCE-based methods, is the most refined task for a given dataset. This is an additional reason1 why representations learned this way have proven useful for various downstream tasks. The article proves that a representation that is style-invariant for a task Y_R remains style-invariant for every task Y_t of which Y_R is a refinement. Therefore, if we can learn a representation that discriminates contents regardless of style, this representation will also work well on content-based downstream tasks.
So, adding a regularization term to the InfoNCE loss encourages a style-independent image representation: we require that an image's representations remain close to each other under different style changes, while staying far from the representations of images with different content.
1 - There are other explanations, which connect this method to maximizing the mutual information between the data and its representation.
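Schematically, the overall training objective combines the two parts (in my notation, not the paper's exact equation):

\mathcal{L} = \mathcal{L}_{\text{InfoNCE}} + \alpha \sum_{(a_1, a_2)} \mathrm{KL}\left( p^{a_1} \,\|\, p^{a_2} \right)

where p^{a} is the similarity distribution of a sample's representation under style change a, and the hyperparameter \alpha weights the invariance penalty.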
Let's figure out what this new regularization term is all about:
Regularization term calculation:
- We build two sets of augmentations (style changes), A1 and A2, where each set is composed of different augmentation methods (a1 from A1, a2 from A2).
For every example xi:
- For each pair of style changes (a1, a2) from A1 and A2, we evaluate the similarity distribution between the representation of xi under a1 and the representations of the rest of the mini-batch samples under a2. For this, we apply a1 to xi and calculate its similarity vector d against the representations of the rest of the samples under a2. Each similarity is calculated as the exponent of the inner product of the two representations, after both are passed through a shallow neural network with one or two layers.
- The vector d is normalized so it represents a probability distribution, denoted p1.
- The analogous computation with the roles of a1 and a2 swapped yields p2, and the KL divergence is calculated between p1 and p2 (an interesting choice: instead of comparing representations directly, it checks how the similarity distribution changed) and summed over all augmentation pairs from A1 and A2; see the sketch below.
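Here is a rough PyTorch sketch of the full loss following the steps above. It is my own simplification (a single augmentation pair per call; the names and the alpha weighting are assumptions), not the authors' implementation.

import torch
import torch.nn.functional as F

def relic_style_loss(z_a1, z_a2, temperature=0.1, alpha=1.0):
    # z_a1, z_a2: (N, D) representations of the same mini-batch under two
    # style changes a1 and a2, already passed through the shallow critic
    # network mentioned above.
    z_a1 = F.normalize(z_a1, dim=1)
    z_a2 = F.normalize(z_a2, dim=1)

    # Contrastive part: the matching sample across the two views is the
    # positive; the rest of the batch serve as negatives (as in InfoNCE).
    logits = z_a1 @ z_a2.t() / temperature  # (N, N) similarity matrix
    targets = torch.arange(z_a1.size(0), device=z_a1.device)
    nce = F.cross_entropy(logits, targets)

    # Invariance part: each sample's similarity distribution over the rest
    # of the batch should not depend on which style change was applied, so
    # we penalize the KL divergence between the two distributions.
    log_p1 = F.log_softmax(z_a1 @ z_a1.t() / temperature, dim=1)
    p2 = F.softmax(z_a2 @ z_a2.t() / temperature, dim=1)
    kl = F.kl_div(log_p1, p2, reduction='batchmean')

    return nce + alpha * kl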
Achievements
The article shows that RELIC representations are stronger than those learned by previous representation learning methods (BYOL, AMDIM, SimCLR) in three different aspects:
- Linear discriminant ratio (LDR), which measures the separation between representations of different classes. A higher LDR corresponds to larger distances between the centers of different class clusters, as well as smaller class diameters. A higher LDR indicates that it is easier to distinguish between samples of different categories with a linear classifier; hence the representation is stronger and better.
- Better performance on different downstream tasks (classification tasks)
- And this is new and cool: the representation strength was also tested on reinforcement learning tasks, where RELIC successfully improved performance.
Datasets: ImageNet ILSVRC-2012
P.S.
The article suggests an interesting idea for improving the performance of NCE-based representation learning: adding a regularization term to the standard NCE loss function. The purpose of this term is to make image representations invariant to style changes. The article demonstrates that the method creates better and stronger representations than previous SOTA methods. I would be happy to see this approach generalized to other domains and other task types as well.
#deepnightlearners
This post was written by Michael (Mike) Erlihson, Ph.D.
Michael works at the cybersecurity company Salt Security as a principal data scientist. Michael researches and works in the deep learning field while lecturing and making scientific material accessible to the general public.