Paper Review - Disentangled Contrastive Learning on Graphs
Revisiting Contrastive Learning
Contrastive learning is an instance-wise discriminative approach that pulls similar instances closer together and pushes dissimilar instances apart in the representation space.
- treats each instance in the dataset as a distinct class of its own and trains a classifier to distinguish between individual instance classes
Given a dataset $\mathcal{D} = \{x_i\}_{i=1}^{N}$, each instance $x_i$ is assigned a unique surrogate label $y_i$ (since no ground-truth labels are given).
- The unique surrogate label is often regarded as the ID of the instance in the dataset, so that $y_i = i$.
Classifier
Then the classifier is defined as:

$$p_\theta(y_i \mid x_i) = \frac{\exp\big(\mathrm{sim}(v_i, u_i)/\tau\big)}{\sum_{j=1}^{N} \exp\big(\mathrm{sim}(v_i, u_j)/\tau\big)}$$

- $\theta$ denotes the parameters of the encoder
- $v_i$ and $u_i$ are the embeddings of $x_i$, generated from two different encoders (or one shared encoder)
- $\mathrm{sim}(\cdot,\cdot)$ is the similarity function
- often cosine similarity with temperature $\tau$, assuming the embeddings are L2-normalized
Objective
- To maximize the joint probability $\prod_{i=1}^{N} p_\theta(y_i \mid x_i)$ over the dataset
- Equivalently, minimize the negative log-likelihood $\mathcal{L}(\theta) = -\sum_{i=1}^{N} \log p_\theta(y_i \mid x_i)$
The loss could be NCE loss, InfoNCE loss, or NT-Xent loss.
InfoNCE Loss
The Information Noise-Contrastive Estimation (InfoNCE) loss is a popular choice of contrastive objective.
In the learned representation space:
- maximizes the agreement between positive samples
- minimizes the agreement between negative samples
The model learns to discriminate between similar and dissimilar instances, leading to improved performance on downstream tasks.
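A minimal PyTorch-style sketch of this instance-wise objective, where each instance in the batch acts as its own class and the positive pair sits on the diagonal of the similarity matrix; the function name, batch layout, and default temperature are illustrative assumptions rather than any particular paper's implementation.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(embed_a: torch.Tensor, embed_b: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """InfoNCE over a batch: row i of embed_a and row i of embed_b form a positive pair;
    all other rows of embed_b act as negatives for instance i."""
    # L2-normalize so the dot product equals cosine similarity
    a = F.normalize(embed_a, dim=1)
    b = F.normalize(embed_b, dim=1)
    # Pairwise cosine similarities scaled by the temperature: logits[i, j] = sim(a_i, b_j) / tau
    logits = a @ b.t() / temperature
    # The positive for instance i sits on the diagonal, i.e. surrogate label y_i = i
    targets = torch.arange(a.size(0), device=a.device)
    # Cross-entropy of this softmax "classifier" is the negative log-likelihood of the positives
    return F.cross_entropy(logits, targets)
```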
Disentangled Contrastive Learning on Graphs
Paper: Disentangled Contrastive Learning on Graphs (NeurIPS 2021)
Problem:
- Entanglement is severe in the representations learned by existing graph contrastive learning methods
- The key idea of these methods is to maximize the agreement (i.e., similarity) between different transformations of the input graph
This paper:
- First work to apply contrastive learning to disentangled self-supervised graph representation learning (DGCL)
- Tailored graph encoder for disentangled contrastive learning
- Tailored discrimination tasks for disentangled contrastive learning on graphs
The paper also explains that although some prior works use VAEs on graphs for disentanglement, the reconstruction step in generative methods can be computationally expensive and may introduce bias that harms the learned representation. Moreover, reconstructing graph-structured data often involves discrete decisions that are not differentiable.
Method
The framework of the DGCL model:
- The input graph $G$ undergoes graph augmentation to produce a correlated view $G'$, and both of them are fed into the shared disentangled graph encoder.
- The node features are first aggregated by L message-passing layers and then taken as the input of a multi-channel message-passing layer.
- Based on the disentangled graph representation $z$, the factor-wise contrastive learning aims to maximize the agreement under each latent factor and provides feedback for the encoder to improve the disentanglement.
- This example assumes that there are three latent factors ($K = 3$), hence the three channels.
Inputs
Graph Augmentation
One of the following types of data augmentation is randomly performed on each input graph (a code sketch follows the list):
- Node dropping.
- Given the input graph, it will randomly discard 20% of the nodes along with their edges, implying that the missing nodes do not affect the model predictions much.
- Edge perturbation.
- Given the input graph, it will randomly add or cut a certain portion of the connections between nodes with probability 0.2. This augmentation can promote robustness of the graph encoder to variance in edge connectivity patterns.
- Attribute masking.
- It will set the features of 20% of the nodes in the graph to Gaussian noise with mean and standard deviation of 0.5. The underlying prior is that missing part of the node features does not affect the semantic information of the whole graph.
- Subgraph sampling.
- It will sample a subgraph containing 20% of the nodes of the input graph using random walks. The assumption is that the semantic information of the whole graph can be reflected by its partial structure.
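A rough sketch of two of these augmentations (node dropping and attribute masking) on a PyG-style `(x, edge_index)` graph; the tensor layout, function names, and default ratios are assumptions that merely mirror the description above.

```python
import torch

def drop_nodes(x: torch.Tensor, edge_index: torch.Tensor, drop_ratio: float = 0.2):
    """Randomly discard a fraction of nodes together with their incident edges."""
    num_nodes = x.size(0)
    keep_mask = torch.rand(num_nodes) >= drop_ratio            # True for surviving nodes
    keep_idx = keep_mask.nonzero(as_tuple=False).view(-1)
    # Relabel surviving nodes to a compact 0..K-1 index range
    relabel = -torch.ones(num_nodes, dtype=torch.long)
    relabel[keep_idx] = torch.arange(keep_idx.numel())
    # Keep only edges whose both endpoints survive
    src, dst = edge_index
    edge_mask = keep_mask[src] & keep_mask[dst]
    new_edge_index = relabel[edge_index[:, edge_mask]]
    return x[keep_idx], new_edge_index

def mask_attributes(x: torch.Tensor, mask_ratio: float = 0.2, mean: float = 0.5, std: float = 0.5):
    """Replace the features of a random fraction of nodes with Gaussian noise."""
    x = x.clone()
    mask = torch.rand(x.size(0)) < mask_ratio
    num_masked = int(mask.sum())
    x[mask] = torch.randn(num_masked, x.size(1)) * std + mean
    return x
```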
Disentangled Graph Encoder
Message Passing
GNNs generally adopt a neighborhood aggregation (message-passing) paradigm (a minimal layer sketch follows the list):
- the embedding of node $v$ is iteratively updated by aggregating the embeddings of its neighbors:

$$h_v^{(l)} = \mathrm{COMBINE}^{(l)}\Big(h_v^{(l-1)},\ \mathrm{AGGREGATE}^{(l)}\big(\{\, h_u^{(l-1)} : u \in \mathcal{N}(v) \,\}\big)\Big)$$

- $h_v^{(l)}$ is the representation of node $v$ at the $l$-th layer
- $\mathcal{N}(v)$ is the neighborhood of node $v$
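As an illustration of one such layer, here is a minimal mean-aggregation variant operating on a dense adjacency matrix; the specific transform and activation are assumptions, not the paper's encoder.

```python
import torch
import torch.nn as nn

class MeanAggregationLayer(nn.Module):
    """One message-passing step: h_v <- ReLU(W * mean({h_u : u in N(v) or u = v}))."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # adj: dense (N, N) adjacency; add self-loops so each node keeps its own message
        adj = adj + torch.eye(adj.size(0), device=adj.device)
        # Row-normalize so that each node averages over its neighborhood
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1.0)
        h_agg = (adj / deg) @ h
        return torch.relu(self.linear(h_agg))
```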
Multi-Channel Message Passing Layer
- Each channel is tailored to aggregate features only from one disentangled latent factor
- A separate READOUT operation (pooling) in each channel summarizes the specific aspect of the graph according to the corresponding latent factor
- Produces the disentangled graph representation $z = [z_1, z_2, \ldots, z_K]$, one component per latent factor (see the sketch after this list)
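A sketch of the multi-channel idea under the assumption of $K$ latent factors: $K$ channel-specific projections, each followed by its own READOUT (mean pooling here), produce one graph-level vector per factor. Dimensions, projection type, and pooling choice are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MultiChannelReadout(nn.Module):
    """Map node embeddings into K factor-specific channels and pool each channel
    into a graph-level vector, giving a disentangled representation [z_1, ..., z_K]."""
    def __init__(self, in_dim: int, channel_dim: int, num_factors: int):
        super().__init__()
        self.channels = nn.ModuleList(
            [nn.Linear(in_dim, channel_dim) for _ in range(num_factors)]
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (num_nodes, in_dim) node embeddings from the shared message-passing layers
        z_per_factor = []
        for proj in self.channels:
            z_k = torch.relu(proj(h))              # channel-specific node features
            z_per_factor.append(z_k.mean(dim=0))   # READOUT: mean pooling over nodes
        return torch.stack(z_per_factor)           # (K, channel_dim) disentangled representation
```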
Factor-wise Contrastive Learning
DGCL defines a novel factor-wise instance discrimination task and learns to solve this task under each latent factor independently.
- makes similar samples closer and dissimilar samples far from each other in the representation space (like traditional Contrastive Learning)
- but also encourages the learned representation to incorporate factor-level information for disentanglement
Idea: assume that the formation of real-world graphs is usually driven by multiple latent heterogeneous factors
The instance discrimination task is represented as the expectation of several subtasks under the latent factors:

$$p_\theta(y \mid G) = \mathbb{E}_{p_\theta(k \mid G)}\big[\, p_\theta(y \mid G, k) \,\big]$$

- the instance discrimination subtask under the $k$-th latent factor takes the same softmax form as before, but is computed on the factor-$k$ components of the disentangled representations
- latent factor prototypes
- $z$ and $z'$ are the disentangled representations produced by the shared graph encoder, and $y$ is the unique surrogate label of the graph.
Factor-wise contrastive learning is conducted for each latent factor independently (a sketch follows below).
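A rough sketch of how the factor-wise discrimination could be computed: an InfoNCE-style subtask per channel, combined under per-graph factor weights $q$. The names, shapes, and the way $q$ is obtained are assumptions; the paper derives the exact weighting from the ELBO discussed below.

```python
import torch
import torch.nn.functional as F

def factor_wise_contrastive_loss(z: torch.Tensor, z_prime: torch.Tensor,
                                 q: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """z, z_prime: (batch, K, dim) disentangled representations of the two views.
    q: (batch, K) weights over latent factors (e.g. an inferred factor distribution)."""
    batch, num_factors, _ = z.shape
    targets = torch.arange(batch, device=z.device)
    per_factor_losses = []
    for k in range(num_factors):
        a = F.normalize(z[:, k], dim=1)
        b = F.normalize(z_prime[:, k], dim=1)
        logits = a @ b.t() / temperature                        # instance discrimination under factor k
        # Per-instance negative log-likelihood of the positive pair under factor k
        per_factor_losses.append(F.cross_entropy(logits, targets, reduction="none"))
    loss_matrix = torch.stack(per_factor_losses, dim=1)         # (batch, K)
    # Expectation of the subtask losses under the factor weights
    return (q * loss_matrix).sum(dim=1).mean()
```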
Objective
Learn the model parameters $\theta$ by maximizing the joint probability over the graph dataset $\mathcal{D}$:

$$\max_{\theta}\ \prod_{i=1}^{N} p_\theta(y_i \mid G_i)$$

However, directly maximizing the log-likelihood function is difficult because the latent factors are unobserved.
- Optimize the evidence lower bound (ELBO) of the log-likelihood function.
The Evidence Lower Bound (ELBO) Theorem
Theorem: The log likelihood function of each graph is lower bounded by the ELBO.
Proof:
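Below is a generic sketch of the standard derivation via Jensen's inequality, written with placeholder notation ($G'$ for the augmented view, $k$ for the latent factor index, $q(k \mid G, G')$ for the variational distribution); the paper's exact equations may differ in detail.

$$
\begin{aligned}
\log p_\theta(y \mid G)
  &= \log \sum_{k=1}^{K} p_\theta(k \mid G)\, p_\theta(y \mid G, k) \\
  &= \log \sum_{k=1}^{K} q(k \mid G, G')\,
     \frac{p_\theta(k \mid G)\, p_\theta(y \mid G, k)}{q(k \mid G, G')} \\
  &\geq \sum_{k=1}^{K} q(k \mid G, G')
     \log \frac{p_\theta(k \mid G)\, p_\theta(y \mid G, k)}{q(k \mid G, G')}
     \quad \text{(Jensen's inequality)} \\
  &= \mathbb{E}_{q(k \mid G, G')}\big[\log p_\theta(y \mid G, k)\big]
     - D_{\mathrm{KL}}\big(q(k \mid G, G') \,\|\, p_\theta(k \mid G)\big)
  \;=\; \mathrm{ELBO}.
\end{aligned}
$$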
To make the ELBO as tight as possible:
- Require that the variational distribution $q(k \mid G, G')$ is close to the true posterior over the latent factors
- $q(k \mid G, G')$ is a variational distribution that infers the posterior distribution of the latent factors after observing both $G$ and its correlated view $G'$
- NT-Xent loss
We calculate and maximize the ELBO over a mini-batch using mini-batch gradient ascent.
Overall Algorithm
GNN(·) is used to indicate the message-passing layers in the algorithm.
Eq. (9) in the paper:
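To tie the pieces together, here is a hypothetical end-to-end training loop (augmentation, shared disentangled encoder, factor-wise contrastive loss, mini-batch gradient ascent on the ELBO). All names (`DisentangledEncoder`-style modules, `augment`, the data loader) are placeholders, not the paper's code.

```python
import torch

def train_dgcl(encoder, factor_head, augment, loader, epochs: int = 100, lr: float = 1e-3):
    """encoder: shared disentangled graph encoder producing (batch, K, dim) representations.
    factor_head: module inferring a (batch, K) factor distribution q from the two views.
    augment: callable applying one randomly chosen graph augmentation to a batch of graphs."""
    params = list(encoder.parameters()) + list(factor_head.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)
    for _ in range(epochs):
        for graphs in loader:
            view = augment(graphs)              # correlated view G'
            z = encoder(graphs)                 # (batch, K, dim) disentangled representation
            z_prime = encoder(view)             # shared encoder applied to the augmented view
            q = factor_head(z, z_prime)         # variational factor distribution q(k | G, G')
            # Per-factor contrastive loss from the sketch above
            # (the KL regularizer of the ELBO is omitted here for brevity)
            loss = factor_wise_contrastive_loss(z, z_prime, q)
            optimizer.zero_grad()
            loss.backward()                     # minimizing the loss ~ maximizing the mini-batch ELBO
            optimizer.step()
    return encoder
```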