Paper Review - AlphaFold2 and AlphaFold3
Prerequisite Knowledge
- MSA Transformers
- Diffusion Models
- Domain knowledge of proteins
AlphaFold 2 (2021)
Highly accurate protein structure prediction with AlphaFold
- Inputs: Protein Data Bank (PDB) dataset (> 1TB)
- EvoFormer (Recycling): processes the MSA and Pair representations into Single and Pair representations
- Recycling: The iterative refinement using the whole network contributes markedly to accuracy with minor extra training time.
- Structure Module
How Recycling is done:
```python
# In the init function
# ....
# recycling related
self.recycle_single = nn.Sequential(
    nn.LayerNorm(dim_single),
    LinearNoBias(dim_single, dim_single)
)
self.recycle_pairwise = nn.Sequential(
    nn.LayerNorm(dim_pairwise),
    LinearNoBias(dim_pairwise, dim_pairwise)
)
# ....

# In the forward function
# init recycled single and pairwise
recycled_pairwise = recycled_single = None
single = pairwise = None

# for each recycling step
for _ in range(num_recycling_steps):
    # handle recycled single and pairwise if not the first step
    recycled_single = recycled_pairwise = 0.

    if exists(single):
        recycled_single = self.recycle_single(single)

    if exists(pairwise):
        recycled_pairwise = self.recycle_pairwise(pairwise)

    single = single_init + recycled_single
    pairwise = pairwise_init + recycled_pairwise

    # single and pairwise are then updated by the Template, MSA and PairFormer modules
    # ....
```
- An extra LayerNorm -> Linear layer is introduced as the recycle layer. In each cycle the previous single and pair outputs are passed through it and added back onto the initial representations before the next pass through the network.
EvoFormer
Usage of EvoFormer: To view the prediction of protein structures as a graph inference problem in 3D space in which the edges of the graph are defined by residues in proximity.
MSA Representation
- The columns of the MSA representation encode the individual residues of the input sequence while the rows represent the sequences in which those residues appear.
- The MSA representation updates the pair representation through an element-wise outer product that is summed over the MSA sequence dimension.
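As a rough illustration of this outer-product update (a minimal sketch; proj_a, proj_b and the final flattening are assumptions loosely modeled on AlphaFold2's OuterProductMean, not the exact implementation):

```python
import torch

def outer_product_mean(msa_repr, proj_a, proj_b):
    # msa_repr: (num_seqs, num_res, c_msa); proj_a / proj_b: small linear projections (assumed)
    a = proj_a(msa_repr)   # (s, r, c)
    b = proj_b(msa_repr)   # (s, r, c)
    # element-wise outer product between residue i and residue j features,
    # averaged over the MSA sequence dimension s
    outer = torch.einsum('sic,sjd->ijcd', a, b) / msa_repr.shape[0]
    # flatten the (c, d) feature block; a final linear layer would map this into the pair dim
    return outer.flatten(-2)   # (r, r, c * d)
```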
Pair Representation
- The elements of the pair representation encode information about the relation between the residues.
- Pair Representation (r, r', c):
- Grid Layout (Left):
  - A matrix in which each cell represents a pair of nodes (i, j) with a particular feature dimension.
  - The indices i, j, k represent different nodes in the graph.
  - Cells such as ij, ik, ji, jk, ki, kj are highlighted, indicating the pairs of nodes being considered.
  - This grid helps in visualizing how each pair of nodes (i, j) interacts with the others.
- Corresponding Edges in a Graph (Right):
  - A directed graph where nodes are connected by directed edges.
  - Nodes i, j, and k are connected by edges such as ij, ji, ik, jk, ki, and kj.
  - These edges represent the relationships between the nodes in the graph.
AlphaFold utilizes two different update patterns within the pair representation, inspired by the necessity to satisfy the constraints of a 3D protein structure. For a pairwise description of amino acids to form a consistent 3D structure, several constraints, including the triangle inequality on distances, must be met. The triangle inequality states that the sum of the lengths of any two sides of a triangle must be greater than the length of the remaining side. This principle is vital for the geometric consistency of the protein structure.
Triangle Multiplicative Update Outgoing Edges:
- Updates edge ij using edges ik and jk: this corresponds to the mix equation
  ... i k d, ... j k d -> ... i j d
  in einsum notation.
- Used to understand pairwise interactions.
Triangle Multiplicative Update Incoming Edges:
- Updates edge ij using edges ki and kj: this corresponds to the mix equation
  ... k j d, ... k i d -> ... i j d
  in einsum notation.
- Used to understand pairwise interactions.
Triangle Self-Attention Around Starting Node:
- Focuses on how node i is related to nodes j and k: attention is over the outgoing edges from i.
- Understanding how updates can be propagated through the graph using different edges.
Triangle Self-Attention Around Ending Node:
- Focuses on how node j is related to nodes i and k: attention is over the incoming edges into j.
- Understanding how updates can be propagated through the graph using different edges.
- Triangle Multiplicative Update and Triangle Self-Attention
- we define a non-attention update operation ‘triangle multiplicative update’ that uses two edges to update the missing third edge
- By using two edges of a triangle to update the third edge, this update operation establishes geometric consistency
- we add an extra logit bias to axial attention to include the ‘missing edge’ of the triangle
- This means that when considering a triangle formed by nodes i, j and k, the logit bias helps to account for the edge not directly considered in a particular attention step, ensuring that all pairwise relationships within the triangle are considered.
- Transition
- It's just a feed-forward block: a linear expansion, a non-linearity, and a linear projection back down.
Interestingly, the paper only mentions: "The triangle multiplicative update was developed originally as a more symmetric and cheaper replacement for the attention, and networks that use only the attention or multiplicative update are both able to produce high-accuracy structures. However, the combination of the two updates is more accurate."
```python
class TriangleMultiplication(Module): ...  # implementation elided
class TriangleAttention(Module): ...       # implementation elided
```
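For reference, here is a minimal sketch of the outgoing triangle multiplicative update, written from the einsum equation above. The class name, hidden width and gating layout are illustrative and only loosely follow the AlphaFold2 algorithm, not any particular codebase:

```python
import torch
from torch import nn

class TriangleMultiplicationOutgoing(nn.Module):
    def __init__(self, dim_pairwise, dim_hidden=128):
        super().__init__()
        self.norm_in = nn.LayerNorm(dim_pairwise)
        self.left_proj = nn.Linear(dim_pairwise, dim_hidden)
        self.right_proj = nn.Linear(dim_pairwise, dim_hidden)
        self.left_gate = nn.Linear(dim_pairwise, dim_hidden)
        self.right_gate = nn.Linear(dim_pairwise, dim_hidden)
        self.out_gate = nn.Linear(dim_pairwise, dim_pairwise)
        self.norm_out = nn.LayerNorm(dim_hidden)
        self.proj_out = nn.Linear(dim_hidden, dim_pairwise)

    def forward(self, pair):
        # pair: (batch, n, n, dim_pairwise)
        z = self.norm_in(pair)
        left = self.left_proj(z) * self.left_gate(z).sigmoid()     # features of edges i -> k
        right = self.right_proj(z) * self.right_gate(z).sigmoid()  # features of edges j -> k
        # combine the two edges that share node k to update the third edge i -> j
        out = torch.einsum('bikd,bjkd->bijd', left, right)
        out = self.proj_out(self.norm_out(out))
        return out * self.out_gate(z).sigmoid()                    # gated output
```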
Structure Module
IPA (Invariant Point Attention)
A geometry-aware attention operation (i.e. IPA) is used to update single representation without changing the 3D positions, then an equivariant update operation is performed on the residue gas using the updated activations.
- Augments each of the usual attention queries, keys and values with 3D points that are produced in the local frame of each residue such that the final value is invariant to global rotations and translations.
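A toy sketch of the invariance property: if all residue frames are moved by the same global rotation and translation, the point-distance logits do not change. The function below is a simplified stand-in for the point term of IPA (single head, no learned weights), not the real module:

```python
import torch

def point_attention_logits(frames_R, frames_t, q_pts, k_pts):
    # frames_R: (n, 3, 3) rotations, frames_t: (n, 3) translations mapping local -> global
    # q_pts, k_pts: (n, p, 3) points predicted in each residue's local frame
    q_glob = torch.einsum('nij,npj->npi', frames_R, q_pts) + frames_t[:, None]
    k_glob = torch.einsum('nij,npj->npi', frames_R, k_pts) + frames_t[:, None]
    # squared distances between residues' point sets; closer points -> larger logit
    d2 = ((q_glob[:, None] - k_glob[None, :]) ** 2).sum(dim=(-1, -2))
    return -d2

# toy check: apply one global rotation R and translation t to every frame
n, p = 5, 3
frames_R = torch.linalg.qr(torch.randn(n, 3, 3)).Q
frames_t = torch.randn(n, 3)
q_pts, k_pts = torch.randn(n, p, 3), torch.randn(n, p, 3)

R = torch.linalg.qr(torch.randn(3, 3)).Q
t = torch.randn(3)

logits = point_attention_logits(frames_R, frames_t, q_pts, k_pts)
logits_moved = point_attention_logits(R @ frames_R, frames_t @ R.T + t, q_pts, k_pts)
print(torch.allclose(logits, logits_moved, atol=1e-4))  # True: logits are invariant
```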
Training of AlphaFold 2
AlphaFold 2 training is split into two phases:
- Initial Training Phase
- More Computationally Intensive Fine-tuning Phase
- The size of protein fragments used for training is increased to 384 residues
- An additional loss term penalizes structural violations
OpenFold (2024)
- A group of researchers reimplemented AlphaFold2 to look for further insights.
- Introduced OpenProteinSet, a training dataset to reproduce AlphaFold2 performance
OpenFold generalization capacity on elided training sets.
- (A) Validation set lDDT-Cα as a function of training step for models trained on elided training sets (10k random split repeated 3x).
- (B) Same as (A) but for CATH-stratified dataset elisions. Validation sets vary across stratifications and are not directly comparable.
- (C) Experimental structures (orange) and mainly alpha-trained (yellow) and mainly beta-trained (red) predictions of largely helical Lsi1 (top) and beta sheet-heavy TMED1 (bottom).
Insights: Efficient Training
It's possible to train with far less compute
OpenFold achieves:
- ~90% of its final accuracy in just 1,500 GPU hours (~3% of the total training time)
- ~95% of its final accuracy in just 2,500 GPU hours
- Total training time is ~50,000 GPU hours
A small dataset can achieve full-model performance
- About 7.6% of the training data suffices to reach essentially the same lDDT-Cα (model accuracy) as a model trained on the full training set.
AlphaFold2 can generalize well
- Filtering the training set by CATH categories (removing entire structural classes) still yields good performance
Insights: Components Analysis
The second training phase only resolves chemical constraints
- The second (fine-tuning) training phase has only a modest effect on the overall structure
- The primary utility of fine-tuning appears to be to resolve violations of known chemical constraints
Templates are not what you need
- Templates have a minimal effect except when MSAs are shallow or entirely absent.
Improvements on OpenFold
Improve Training Stability and Speed up model convergence
- FP16 low-precision training
- Uses a different clamping protocol in the primary structural loss, FAPE
AlphaFold 3 (2024)
Accurate structure prediction of biomolecular interactions with AlphaFold 3
Modalities: Sequences, Ligands, Covalent bonds (Raw atoms + coarse abstract token)
- Inputs: Protein Data Bank (PDB) dataset (> 1TB)
- Input Embedder: extracts representations (pair, single)
- PairFormer (Recycling): to view the prediction of protein structures as a graph inference problem in 3D space in which the edges of the graph are defined by residues in proximity.
- Diffusion Decoder: builds out the 3D coordinates
PDB Dataset Curation
- Extract mmCIF files
- mmCIF is a format that describes the 3D structures of biomolecules.
- Filter mmCIF files
- filter out certain bioassemblies
- select the closest chains in large bioassemblies
Input Embedder
- MLP as first projection
- Atom Transformer (DiT)
- another MLP as last projection
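A rough sketch of these three stages (dimensions, depth and the fixed-window atom-to-token pooling are assumptions for illustration; the real Atom Transformer uses DiT-style conditioning and windowed attention rather than a plain encoder):

```python
import torch
from torch import nn

class InputEmbedderSketch(nn.Module):
    def __init__(self, dim_atom_in=77, dim_atom=128, dim_token=384, depth=3, heads=4):
        super().__init__()
        # first MLP projection of raw per-atom features
        self.proj_in = nn.Sequential(nn.Linear(dim_atom_in, dim_atom), nn.ReLU(),
                                     nn.Linear(dim_atom, dim_atom))
        # stand-in for the Atom Transformer: plain self-attention over atoms
        layer = nn.TransformerEncoderLayer(d_model=dim_atom, nhead=heads, batch_first=True)
        self.atom_transformer = nn.TransformerEncoder(layer, num_layers=depth)
        # last MLP projection up to the per-token width
        self.proj_out = nn.Sequential(nn.Linear(dim_atom, dim_token), nn.ReLU(),
                                      nn.Linear(dim_token, dim_token))

    def forward(self, atom_feats, atoms_per_token=27):
        # atom_feats: (batch, num_atoms, dim_atom_in); num_atoms divisible by atoms_per_token
        atoms = self.proj_out(self.atom_transformer(self.proj_in(atom_feats)))
        b, n, d = atoms.shape
        # pool fixed windows of atoms into coarse per-token representations
        return atoms.reshape(b, n // atoms_per_token, atoms_per_token, d).mean(dim=2)
```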
Template Module (TemplateEmbedder)
- Add Pair representation and template embedding together
- 2 Stacks of TriangleMultiplication+TriangleAttention+Projection (Same as AlphaFold2)
MSA Module
- 4 Stacks of TriangleMultiplication+TriangleAttention+Projection (Same as AlphaFold2)
- and some extra processing to modify the pairwise representation using the MSA mask
PairFormer
has 48 PairFormerBlocks
- Triangle Update: Same as AlphaFold2
- Triangle Self-Attention: Same as AlphaFold2
- Transition: LinearNoBias(dim, dim*4*2), SwiGLU(), LinearNoBias(dim*4, dim)
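Written out directly from the layer spec above, the transition looks roughly like this (a minimal sketch; the LayerNorm placement and the gate/value split order are assumptions):

```python
import torch
from torch import nn
import torch.nn.functional as F

class Transition(nn.Module):
    def __init__(self, dim, expansion=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj_in = nn.Linear(dim, dim * expansion * 2, bias=False)  # LinearNoBias(dim, dim*4*2)
        self.proj_out = nn.Linear(dim * expansion, dim, bias=False)     # LinearNoBias(dim*4, dim)

    def forward(self, x):
        gate, value = self.proj_in(self.norm(x)).chunk(2, dim=-1)
        return self.proj_out(F.silu(gate) * value)                      # SwiGLU: silu(gate) * value
```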
Diffusion
The diffusion module operates directly on raw atom coordinates, and on a coarse abstract token representation, without rotational frames or any equivariant processing.
The diffusion module. Input: coarse arrays depict per-token representations (green, inputs; blue, pair; red, single). Fine arrays depict per-atom representations. The coloured balls represent physical atom coordinates. Cond., conditioning; rand. rot. trans., random rotation and translation; seq., sequence.
- A Diffusion Transformer conditioned on the single representation and pair representation
- Elucidated Diffusion Model (EDM) adapted for atom position diffusing
- At each step, a random point cloud of atoms is denoised, biased by per-token and atom information.
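To make "denoise a random point cloud" concrete, here is a minimal Euler sampler in the EDM style; denoise_fn stands in for the conditioned diffusion module, and the schedule constants are illustrative rather than the paper's values:

```python
import torch

@torch.no_grad()
def edm_sample(denoise_fn, shape, num_steps=200, sigma_min=4e-4, sigma_max=160.0, rho=7.0):
    # noise-level schedule decreasing from sigma_max to sigma_min (Karras et al. parameterization)
    steps = torch.linspace(0.0, 1.0, num_steps)
    sigmas = (sigma_max ** (1 / rho) + steps * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])

    x = torch.randn(shape) * sigmas[0]                 # start from a random atom point cloud
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoise_fn(x, sigma)                # network's estimate of the clean coordinates
        d = (x - denoised) / sigma                     # direction toward the data manifold
        x = x + d * (sigma_next - sigma)               # Euler step to the next noise level
    return x
```

At the final step sigma_next is zero, so the sampler simply returns the last denoised estimate.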
Why Diffusion Prior?
3D structure generation is difficult to do with an autoregressive (AR) approach, so a diffusion prior is used to generate all atom positions jointly.
```python
class DiffusionTransformer(Module): ...  # implementation elided
```
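As a sketch of how a DiT-style block can be conditioned on the single representation, an adaptive LayerNorm turns the conditioning signal into a per-token scale and shift (names and shapes are illustrative, not the reference implementation):

```python
import torch
from torch import nn

class AdaptiveLayerNorm(nn.Module):
    def __init__(self, dim, dim_cond):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(dim_cond, dim * 2)

    def forward(self, x, cond):
        # x: (batch, tokens, dim) activations; cond: (batch, tokens, dim_cond) conditioning signal
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
```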
Training Detail of AlphaFold 3
Mini-Rollout
- The mini rollout helps refine the structure prediction by iteratively improving the output based on the current model’s state.
- stabilizes the predictions, ensuring they converge towards a more accurate structure.
- The mini rollout involves a sequence of 20 iterations (as indicated in the diagram) where the model repeatedly refines its prediction.
- During each iteration, the model uses the output from the previous iteration as the input for the next, gradually reducing noise and correcting errors.
- After 20 iterations, the refined prediction is used to permute the ground truth chains and ligands, providing a more accurate basis for the subsequent training steps.
- Note that gradients do not flow through the mini rollout and it does not directly contribute to the loss calculation (see the sketch below).
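A sketch of where the mini rollout sits inside a training step; diffusion_loss, confidence_loss, sample_fn and permute_ground_truth are hypothetical helpers, and the key points are the fixed 20 steps and the torch.no_grad() wrapper:

```python
import torch

def training_step(model, sample_fn, batch):
    # gradient-carrying denoising loss on a noised example
    diff_loss = model.diffusion_loss(batch)                    # hypothetical helper

    # 20-step mini rollout with the current weights; no gradients flow through it
    with torch.no_grad():
        pred_structure = sample_fn(model, batch, num_steps=20)

    # the rolled-out structure is only used to permute ground-truth chains/ligands
    # and to build targets for the confidence heads
    permuted_gt = permute_ground_truth(batch, pred_structure)  # hypothetical helper
    conf_loss = model.confidence_loss(batch, permuted_gt)      # hypothetical helper
    return diff_loss + conf_loss
```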
Confidence module
- 1 Pairformer Stack
- 1 Linear Projection for each kind of logits (PAE, PDE, pLDDT, Resolved)
Total Loss:
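The code block here collapsed during extraction; as a sketch, the total loss is a weighted sum of the three terms described below. The default weights are placeholders for illustration, not values quoted from the paper:

```python
def total_loss(distogram_loss, diffusion_loss, confidence_loss,
               w_distogram=3e-2, w_diffusion=4.0, w_confidence=1e-4):
    # weights are illustrative placeholders; see the AlphaFold 3 supplementary for exact values
    return (w_distogram * distogram_loss
            + w_diffusion * diffusion_loss
            + w_confidence * confidence_loss)
```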
Loss functions
Diffusion Loss
- Definition: Training Loss of diffusion module
- Role in Training: The loss function used to train the model to reverse the diffusion process.
- MSE + Bond Loss + Smooth LDDT Loss
Distogram Loss
- Definition: A distogram is a histogram of distances between pairs of residues in a protein.
- Role in Training: This loss helps the model learn the correct distance distributions between residue pairs.
- Cross Entropy
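A minimal sketch of such a distogram cross-entropy; the bin count, the distance range and the use of a single representative atom per residue are assumptions:

```python
import torch
import torch.nn.functional as F

def distogram_loss(distogram_logits, coords, num_bins=64, min_dist=2.0, max_dist=22.0):
    # coords: (num_res, 3) representative-atom coordinates
    # distogram_logits: (num_res, num_res, num_bins) predicted bin logits
    dists = torch.cdist(coords, coords)                                  # true pairwise distances
    edges = torch.linspace(min_dist, max_dist, num_bins - 1, device=coords.device)
    labels = torch.bucketize(dists, edges)                               # bin index per residue pair
    # cross entropy between the predicted bin distribution and the true bin
    return F.cross_entropy(distogram_logits.reshape(-1, num_bins), labels.reshape(-1))
```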
Confidence Loss
- PAE + PDE + pLDDT + Resolved Loss
PAE (Predicted Aligned Error):
- Definition: PAE measures the expected positional error in the predicted structure of a protein when aligned to a reference structure.
- Role in Training: The pae_labels are ground truth values that supervise the model's prediction of positional errors in the protein structure alignment.
- Cross Entropy
PDE (Predicted Distance Error):
- Definition: PDE refers to the predicted error in inter-residue distances within the protein structure.
- Role in Training: The pde_labels are ground truth values that guide the model in predicting errors in the distance matrix of the protein's predicted structure compared to the true structure.
- Cross Entropy
pLDDT (Predicted Local Distance Difference Test):
- Definition: pLDDT is a per-residue confidence score assigned by AlphaFold to its predictions, indicating the reliability of each residue’s predicted position. Each of these residues will have a predicted position in the 3D structure.
- Role in Training: The plddt_labels are ground truth confidence scores used to train the model's prediction of how accurate each residue's position is.
- Cross Entropy
Resolved Loss:
- Definition: Resolved labels are binary indicators that show whether a residue is well-resolved or confidently predicted in the protein structure.
- Role in Training: The resolved_labels serve as ground truth binary indicators to help the model learn to predict which residues are well-resolved in the structure.
- Cross Entropy
Usage Example
- Batch size = 2
```python
import torch
# (remainder of the usage example elided)
```
Extra Info
OpenFold: https://github.com/aqlaboratory/openfold/tree/main
Implementation of Alphafold 3 in Pytorch: https://github.com/lucidrains/alphafold3-pytorch