Revisiting Contrastive Learning

Contrastive learning is an instance-wise discriminative approach that aims at making similar instances closer and dissimilar instances far from each other in representation space.

  • treats each instance in the dataset as a distinct class of its own and trains a classifier to distinguish between individual instance classes

Given a dataset $X = \{x_i\}_{i=1}^{N}$, each instance $x_i$ is assigned a unique surrogate label $y_i$ (since no ground-truth labels are given).

  • The unique surrogate label is often regarded as the ID of the instance in the dataset, so that $y_i = i$.

Classifier

Then the classifier is defined as:

$$p_{\theta}(y_i|x_i) = \frac{\exp \phi(\mathbf{v}_i, \mathbf{v}'_{y_i})}{\sum_{j=1}^{N} \exp \phi(\mathbf{v}_i, \mathbf{v}'_{y_j})}$$

  • $\theta$ denotes the parameters of the encoder
  • $\mathbf{v}_i$ and $\mathbf{v}'_{y_i}$ are the embeddings of $x_i$
    • generated from two different encoders (or one shared encoder)
  • $\phi$ is the similarity function
    • often cosine similarity with temperature $\tau$, assuming the embeddings are L2-normalized (see the sketch after this list)
      • $\phi(\mathbf{v}_i, \mathbf{v}'_{y_i}) = \frac{\text{cosine similarity}(\mathbf{v}_i, \mathbf{v}'_{y_i})}{\tau}$
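
As a concrete reference, here is a minimal sketch of the scoring function $\phi$ as temperature-scaled cosine similarity (plain PyTorch; the function name and defaults are my own, not from any specific implementation):

```python
import torch
import torch.nn.functional as F

def phi(v, v_prime, tau=0.1):
    """Temperature-scaled cosine similarity between two batches of embeddings.

    v, v_prime: tensors of shape (batch, dim); tau: temperature.
    Returns a (batch, batch) matrix of pairwise scores phi(v_i, v'_j).
    """
    v = F.normalize(v, dim=-1)            # L2-normalize so the dot product is cosine similarity
    v_prime = F.normalize(v_prime, dim=-1)
    return v @ v_prime.t() / tau          # pairwise cosine similarities scaled by 1/tau
```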

Objective

  • To maximize the joint probability $\prod_{i=1}^{N} p(y_i \mid x_i)$
  • That means minimizing the negative log-likelihood $\sum_{i=1}^{N} \ell_i$, where $\ell_i = -\log p(y_i \mid x_i)$

The loss $\ell_i$ could be the NCE loss, the InfoNCE loss, or the NT-Xent loss.

InfoNCE Loss

The Information Noise-Contrastive Estimation (InfoNCE) loss is a popular choice.

In the learned representation space:

  • maximizes the agreement between positive samples
  • minimizes the agreement between negative samples

The model learns to discriminate between similar and dissimilar instances, leading to improved performance on downstream tasks.
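
Under these assumptions (two views per instance in a batch, positives on the diagonal and all other in-batch pairs treated as negatives), a minimal InfoNCE sketch in PyTorch looks like the following; the names are illustrative, not tied to a specific library:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(v, v_prime, tau=0.1):
    """InfoNCE over a batch: the i-th rows of v and v_prime are two views of instance i."""
    v = F.normalize(v, dim=-1)
    v_prime = F.normalize(v_prime, dim=-1)
    logits = v @ v_prime.t() / tau                       # (B, B) similarity matrix
    targets = torch.arange(v.size(0), device=v.device)   # the positive pair sits on the diagonal
    return F.cross_entropy(logits, targets)              # -log softmax of the positive score

# usage (hypothetical encoder): loss = info_nce_loss(encoder(x), encoder(x_aug))
```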

Disentangled Contrastive Learning on Graphs

Paper: Disentangled Contrastive Learning on Graphs (NeurIPS 2021)

Problem:

  • Entanglement is severe in the representations learned by existing graph contrastive learning methods
    • The key of these methods is to maximize the agreement (i.e., similarity) between different transformations of the input graphs

This paper:

  • First work to apply contrastive learning to disentangled self-supervised graph representation learning (DGCL)
    • Tailored graph encoder for disentangled contrastive learning
    • Tailored discrimination tasks for disentangled contrastive learning on graphs

The paper also explains that although some works use VAEs on graphs for disentanglement, the reconstruction step in generative methods can be computationally expensive and may introduce bias that negatively affects the learned representations. Moreover, reconstructing graph-structured data often involves discrete decisions that are not differentiable.

Method

Figure: The framework of the DGCL model.

  • The input graph $G_i$ undergoes graph augmentations to produce $G'_i$, and both of them are fed into the shared disentangled graph encoder $f_\theta(\cdot)$.
  • The node features $H^0$ are first aggregated by $L$ message-passing layers and then taken as the input of a multi-channel message-passing layer.
  • Based on the disentangled graph representation $z_i$, the factor-wise contrastive learning aims to maximize the agreement under each latent factor and provide feedback for the encoder to improve the disentanglement.
    • $z_i = [z_{i,1}, z_{i,2}, \dots, z_{i,K}]$
    • This example assumes that there are three latent factors ($K = 3$), hence the three channels.

Inputs

Graph Augmentation

One of the following types of graph data augmentation is randomly applied (a minimal node-dropping sketch follows the list):

  • Node dropping.
    • Given the input graph, it randomly discards 20% of the nodes along with their edges, under the assumption that the missing nodes do not affect the model predictions much.
  • Edge perturbation.
    • Given the input graph, it randomly adds or cuts a portion of the connections between nodes with probability 0.2. This augmentation can promote the robustness of the graph encoder to variances in the edge connectivity pattern.
  • Attribute masking.
    • It sets the features of 20% of the nodes in the graph to Gaussian noise with mean 0.5 and standard deviation 0.5. The underlying prior is that missing part of the features does not affect the semantic information of the whole graph.
  • Subgraph sampling.
    • It samples a subgraph containing 20% of the nodes of the input graph using random walks. The assumption is that the semantic information of the whole graph can be reflected by its partial structure.
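
As an illustration, a minimal node-dropping sketch, assuming a COO-style `edge_index` graph representation (a common graph-library convention, not necessarily the paper's implementation); the other augmentations follow the same pattern:

```python
import torch

def drop_nodes(x, edge_index, drop_ratio=0.2):
    """Node dropping: x is (num_nodes, feat_dim), edge_index is (2, num_edges) in COO format."""
    num_nodes = x.size(0)
    keep = torch.rand(num_nodes) >= drop_ratio           # keep roughly 80% of the nodes
    keep_idx = keep.nonzero(as_tuple=True)[0]
    remap = -torch.ones(num_nodes, dtype=torch.long)     # map surviving old ids to new contiguous ids
    remap[keep_idx] = torch.arange(keep_idx.numel())
    mask = keep[edge_index[0]] & keep[edge_index[1]]     # keep only edges whose endpoints both survive
    return x[keep_idx], remap[edge_index[:, mask]]
```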

Disentangled Graph Encoder

Message Passing

GNNs generally adopt a neighborhood aggregation (message passing) paradigm (a minimal layer sketch follows the definitions below):

  • the embedding of each node is iteratively updated by aggregating the embeddings of its neighbors

$$h_v^l = \text{COMBINE}^l \left( h_v^{l-1}, \text{AGGREGATE}^l \left( \{ h_u^{l-1} : u \in \mathcal{N}(v) \} \right) \right)$$

  • $h_v^l$ is the representation of node $v$ at the $l$-th layer
  • $\mathcal{N}(v)$ is the neighborhood of node $v$
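
A minimal sketch of one such layer with mean aggregation and a linear COMBINE step (a simplified generic form, not necessarily the GNN used in the paper):

```python
import torch
import torch.nn as nn

class MessagePassingLayer(nn.Module):
    """One neighborhood-aggregation step: AGGREGATE (mean over neighbors) + COMBINE (linear + ReLU)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.combine = nn.Linear(2 * in_dim, out_dim)

    def forward(self, h, edge_index):
        # h: (num_nodes, in_dim); edge_index: (2, num_edges), edges point src -> dst
        src, dst = edge_index
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])        # sum messages from neighbors
        deg = torch.zeros(h.size(0), device=h.device).index_add_(
            0, dst, torch.ones(dst.numel(), device=h.device))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)                  # mean aggregation
        return torch.relu(self.combine(torch.cat([h, agg], dim=-1)))  # COMBINE(h_v, aggregated neighbors)
```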

Multi-Channel Message Passing Layer

  • Each channel is tailored to aggregate features only from one disentangled latent factor
  • A separate READOUT operation (pooling) in each channel summarizes the specific aspect of the graph according to the corresponding latent factor
  • Together, the channels produce the disentangled graph representation (a simplified sketch follows this list)
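
A simplified sketch of the multi-channel part: each channel is reduced here to a linear projection plus mean-pool READOUT (in the paper each channel is itself a message-passing layer, so this only approximates the idea):

```python
import torch
import torch.nn as nn

class MultiChannelReadout(nn.Module):
    """K parallel channels; each one summarizes the graph under a single latent factor."""

    def __init__(self, in_dim, factor_dim, num_factors):
        super().__init__()
        self.channels = nn.ModuleList(
            [nn.Linear(in_dim, factor_dim) for _ in range(num_factors)])

    def forward(self, h):
        # h: (num_nodes, in_dim) node embeddings of one graph after the L shared message-passing layers
        z = [torch.relu(ch(h)).mean(dim=0) for ch in self.channels]   # READOUT = mean pooling per channel
        return torch.stack(z, dim=0)   # (K, factor_dim): z_i = [z_{i,1}, ..., z_{i,K}]
```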

Factor-wise Contrastive Learning

DGCL introduces a novel factor-wise instance discrimination task and learns to solve it under each latent factor independently.

  • makes similar samples closer and dissimilar samples far from each other in the representation space (like traditional Contrastive Learning)
  • but also encourages the learned representation to incorporate factor-level information for disentanglement

Idea: assume that the formation of real-world graphs is usually driven by multiple latent heterogeneous factors

The instance discriminative task should be represented as the expectation of several subtasks under the latent factors:

$$p_{\theta}(y_i|G_i) = \mathbb{E}_{p_{\theta}(k|G_i)} \left[ p_{\theta}(y_i|G_i, k) \right]$$

where $p_{\theta}(k \mid G_i)$ infers how likely the $k$-th latent factor is to drive the formation of $G_i$, computed with $K$ latent factor prototypes $\{\mathbf{c}_k\}_{k=1}^{K}$:

$$p_{\theta}(k|G_i) = \frac{\exp \phi(\mathbf{z}_{i, k}, \mathbf{c}_k)}{\sum_{k'=1}^{K} \exp \phi(\mathbf{z}_{i, k'}, \mathbf{c}_{k'})}$$

and $p_{\theta}(y_i \mid G_i, k)$ is the instance discrimination subtask under the $k$-th latent factor:

$$p_{\theta}(y_i|G_i, k) = \frac{\exp \phi(\mathbf{z}_{i, k}, \mathbf{z}'_{y_i, k})}{\sum_{j=1}^{N} \exp \phi(\mathbf{z}_{i, k}, \mathbf{z}'_{y_j, k})}$$

  • $\mathbf{z}_{i,k}$ and $\mathbf{z}'_{y_i,k}$ are the disentangled representations produced by the shared graph encoder, and $y_i$ is the unique surrogate label of the graph.

Conduct factor-wise contrastive learning for each latent factor independently.
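
A sketch of these two quantities computed for a mini-batch (the helper name, shapes, and the in-batch approximation of the denominator are my assumptions, not the paper's code):

```python
import torch
import torch.nn.functional as F

def factor_wise_probs(z, z_prime, prototypes, tau=0.1):
    """z, z_prime: (B, K, d) disentangled representations of a batch and its augmented views.
    prototypes: (K, d) latent factor prototypes c_k.
    Returns p_theta(k|G_i) as (B, K) and per-factor in-batch similarity logits as (K, B, B)."""
    z = F.normalize(z, dim=-1)
    z_prime = F.normalize(z_prime, dim=-1)
    c = F.normalize(prototypes, dim=-1)
    # p_theta(k|G_i): similarity between the k-th channel of z_i and the k-th prototype
    p_k = F.softmax((z * c.unsqueeze(0)).sum(-1) / tau, dim=-1)   # (B, K)
    # logits[k, i, j] = phi(z_{i,k}, z'_{j,k}); the diagonal holds the positive pairs
    logits = torch.einsum('ikd,jkd->kij', z, z_prime) / tau       # (K, B, B)
    return p_k, logits
```

The factor-wise NT-Xent term for factor $k$ is then `F.cross_entropy(logits[k], torch.arange(B))`, i.e., the subtask $p_\theta(y_i \mid G_i, k)$ estimated within the mini-batch.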

Objective

Learn the model parameters $\theta$ by maximizing the joint probability $\prod_{i=1}^{N} p(y_i \mid G_i)$ over the graph dataset $G = \{G_i\}_{i=1}^{N}$:

$$\theta^* = \arg\max_{\theta} \sum_{i=1}^{N} \log p_{\theta}(y_i|G_i) = \arg\max_{\theta} \sum_{i=1}^{N} \log \mathbb{E}_{p_{\theta}(k|G_i)} \left[ p_{\theta}(y_i|G_i, k) \right]$$

However, directly maximizing the log-likelihood is difficult because of the latent factors.

  • Optimize the evidence lower bound (ELBO) of the log-likelihood function.

The Evidence Lower Bound (ELBO) Theorem

Theorem: The log likelihood function of each graph is lower bounded by the ELBO.

Proof:

Figure: derivation of the ELBO (from the paper).
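
The figure shows the derivation; a standard reconstruction via Jensen's inequality (my own sketch, not a verbatim copy of the paper's proof) is:

$$
\begin{aligned}
\log p_{\theta}(y_i \mid G_i)
&= \log \sum_{k=1}^{K} p_{\theta}(k \mid G_i)\, p_{\theta}(y_i \mid G_i, k) \\
&= \log \sum_{k=1}^{K} q_{\theta}(k \mid G_i, y_i)\, \frac{p_{\theta}(k \mid G_i)\, p_{\theta}(y_i \mid G_i, k)}{q_{\theta}(k \mid G_i, y_i)} \\
&\ge \sum_{k=1}^{K} q_{\theta}(k \mid G_i, y_i) \log \frac{p_{\theta}(k \mid G_i)\, p_{\theta}(y_i \mid G_i, k)}{q_{\theta}(k \mid G_i, y_i)} =: \mathrm{ELBO},
\end{aligned}
$$

where the inequality is Jensen's inequality, with equality when $q_\theta(k \mid G_i, y_i)$ equals the true posterior $p_\theta(k \mid G_i, y_i)$.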

To make the ELBO as tight as possible:

  • Require that $q_\theta(k \mid G_i, y_i)$ is close to the true posterior $p_\theta(k \mid G_i, y_i)$
    • $q_\theta(k \mid G_i, y_i)$ is a variational distribution that infers the posterior distribution of the latent factors after observing both $G_i$ and its correlated view $G'_{y_i}$.

$$q_{\theta}(k|G_i, y_i) = \frac{p_{\theta}(k|G_i)\, \hat{p}_{\theta}(y_i|G_i, k)}{\sum_{k'=1}^{K} p_{\theta}(k'|G_i)\, \hat{p}_{\theta}(y_i|G_i, k')}$$

  • $\hat{p}_{\theta}(y_i|G_i, k)$ is the subtask probability estimated within the mini-batch, which takes the form of an NT-Xent loss

We calculate $q_\theta(k \mid G_i, y_i)$ and maximize the ELBO over a mini-batch $\mathcal{B}$ using mini-batch gradient ascent.
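
A sketch of how $q_\theta(k \mid G_i, y_i)$ and the (negative) ELBO could be computed over a mini-batch, reusing the `factor_wise_probs` helper above; the function name and the EM-style stop-gradient on $q$ are my assumptions:

```python
import torch
import torch.nn.functional as F

def neg_elbo_loss(p_k, logits, eps=1e-8):
    """p_k: (B, K) = p_theta(k|G_i); logits: (K, B, B) factor-wise in-batch similarity scores."""
    B = p_k.size(0)
    idx = torch.arange(B, device=p_k.device)
    # \hat{p}_theta(y_i|G_i, k): NT-Xent-style softmax over the in-batch candidates
    p_y_given_k = F.softmax(logits, dim=-1)[:, idx, idx].t()          # (B, K)
    # Eq. (9): q_theta(k|G_i, y_i) is proportional to p_theta(k|G_i) * \hat{p}_theta(y_i|G_i, k)
    q = p_k * p_y_given_k
    q = (q / q.sum(dim=-1, keepdim=True)).detach()                    # treat q as fixed (EM-style; an assumption here)
    # ELBO = E_q[ log p(k|G_i) + log \hat{p}(y_i|G_i, k) - log q ]; minimize its negative
    elbo = (q * (torch.log(p_k + eps) + torch.log(p_y_given_k + eps) - torch.log(q + eps))).sum(-1)
    return -elbo.mean()
```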

Overall Algorithm

In the algorithm, GNN($\cdot$) denotes the message-passing layer.

Eq. (9) in the paper:

$$q_{\theta}(k|G_i, y_i) = \frac{p_{\theta}(k|G_i)\, \hat{p}_{\theta}(y_i|G_i, k)}{\sum_{k'=1}^{K} p_{\theta}(k'|G_i)\, \hat{p}_{\theta}(y_i|G_i, k')}$$

Figure: the overall training algorithm of DGCL (from the paper).
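
Tying the pieces together, a hypothetical end-to-end training loop could look like the sketch below; `encoder`, `augment`, `loader`, `prototypes`, and `optimizer` are placeholders, and `factor_wise_probs` / `neg_elbo_loss` are the helpers sketched earlier, not the paper's code:

```python
import torch

def train_epoch(encoder, augment, prototypes, loader, optimizer, tau=0.1):
    """One epoch of a DGCL-style training loop (illustrative sketch).
    encoder(g) -> (K, d) disentangled representation; augment(g) -> a randomly augmented view of g."""
    for graphs in loader:                                      # mini-batch of graphs
        views = [augment(g) for g in graphs]                   # one random augmentation per graph
        z = torch.stack([encoder(g) for g in graphs])          # (B, K, d)
        z_prime = torch.stack([encoder(g) for g in views])     # shared encoder on the augmented views
        p_k, logits = factor_wise_probs(z, z_prime, prototypes, tau)
        loss = neg_elbo_loss(p_k, logits)                      # negative ELBO over the mini-batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```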