Prerequisite Knowledge

AlphaFold 2 (2021)

Highly accurate protein structure prediction with AlphaFold

img

  • Inputs: Protein Data Bank (PDB) dataset (> 1TB)
  • EvoFormer (Recycling): processes the MSA and pair representations into single and pair representations
    • Recycling: The iterative refinement using the whole network contributes markedly to accuracy with minor extra training time.
  • Structure Module

How Recycling is done:

# In the init function
# ....
# recycling related

self.recycle_single = nn.Sequential(
    nn.LayerNorm(dim_single),
    LinearNoBias(dim_single, dim_single)
)

self.recycle_pairwise = nn.Sequential(
    nn.LayerNorm(dim_pairwise),
    LinearNoBias(dim_pairwise, dim_pairwise)
)
# ....


# In the forward function
# init recycled single and pairwise

recycled_pairwise = recycled_single = None
single = pairwise = None

# for each recycling step

for _ in range(num_recycling_steps):

    # handle recycled single and pairwise if not first step

    recycled_single = recycled_pairwise = 0.

    if exists(single):
        recycled_single = self.recycle_single(single)

    if exists(pairwise):
        recycled_pairwise = self.recycle_pairwise(pairwise)

    single = single_init + recycled_single
    pairwise = pairwise_init + recycled_pairwise

    # Then single, pairwise are updated by passing through the Template, MSA and PairFormer modules

    # ....
  • An extra LayerNorm -> Linear layer is introduced as the recycle layer. In each cycle, the previous cycle's single and pair outputs are passed through this layer and added back onto the initial single and pair representations before the next pass through the network.

EvoFormer

img

Usage of EvoFormer: To view the prediction of protein structures as a graph inference problem in 3D space in which the edges of the graph are defined by residues in proximity.

MSA Representation

  • The columns of the MSA representation encode the individual residues of the input sequence while the rows represent the sequences in which those residues appear.
  • The MSA representation updates the pair representation through an element-wise outer product that is summed over the MSA sequence dimension.
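As a rough illustration of this outer-product update, here is a minimal sketch using torch.einsum. The sizes and projection widths (s sequences, r residues, c channels) are placeholders for illustration, not the actual AlphaFold hyperparameters.

import torch

# hypothetical sizes: s MSA sequences, r residues, c MSA channels, c_z pair channels
s, r, c, c_z = 8, 16, 32, 64

msa = torch.randn(s, r, c)                # MSA representation (sequences x residues x channels)
proj_a = torch.nn.Linear(c, 16)           # two small projections of the MSA channels
proj_b = torch.nn.Linear(c, 16)
to_pair = torch.nn.Linear(16 * 16, c_z)   # maps the flattened outer product to pair channels

a = proj_a(msa)                           # (s, r, 16)
b = proj_b(msa)                           # (s, r, 16)

# element-wise outer product between residue i and residue j, averaged over the MSA sequence dimension s
outer = torch.einsum('s i a, s j b -> i j a b', a, b) / s   # (r, r, 16, 16)

pair_update = to_pair(outer.flatten(-2))  # (r, r, c_z), added onto the pair representation
print(pair_update.shape)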

Pair Representation

  • The elements of the pair representation encode information about the relation between the residues.

    • img

      Pair Representation (r, r’, c):

      • Grid Layout (Left):
        • This shows a matrix representation where each cell represents a pair of nodes (i, j) with a particular feature dimension.
        • The indices i, j, k represent different nodes in the graph.
        • Cells like ij, ik, ji, jk, ki, kj are highlighted, indicating the pairs of nodes being considered.
        • This grid helps in visualizing how each pair of nodes (i, j) interacts with others.

      Corresponding Edges in a Graph (Right):

      • Graph Layout:
        • This shows a directed graph where nodes are connected by directed edges.
        • Nodes i, j, and k are connected by edges such as ij, ji, ik, jk, ki, and kj.
        • These edges represent the relationships between the nodes in the graph.

img

AlphaFold utilizes two different update patterns within the pair representation, inspired by the necessity to satisfy the constraints of a 3D protein structure. For a pairwise description of amino acids to form a consistent 3D structure, several constraints, including the triangle inequality on distances, must be met. The triangle inequality states that the sum of the lengths of any two sides of a triangle must be greater than the length of the remaining side. This principle is vital for the geometric consistency of the protein structure.
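To make the constraint concrete: a pairwise distance matrix derived from actual 3D coordinates automatically satisfies the triangle inequality, whereas an arbitrary pairwise prediction may not. The small check below (plain PyTorch, illustrative only) verifies d(i, j) <= d(i, k) + d(k, j) for every triple of residues.

import torch

coords = torch.randn(16, 3)              # 3D positions of 16 residues
d = torch.cdist(coords, coords)          # pairwise distances, shape (16, 16)

# sums[i, k, j] = d(i, k) + d(k, j)
sums = d[:, :, None] + d[None, :, :]     # shape (16, 16, 16), middle index is k
tightest = sums.min(dim=1).values        # minimum over k, shape (16, 16)

# distances that come from real coordinates always satisfy the triangle inequality
assert torch.all(d <= tightest + 1e-5)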

Triangle Multiplicative Update Outgoing Edges:

  • Updates edge ij using edges ik and jk: this corresponds to the mix equation ... i k d, ... j k d -> ... i j d in einsum notation.
  • Used to understand pairwise interactions.

Triangle Multiplicative Update Incoming Edges:

  • Updates edge ij using edges ki and kj: this corresponds to the mix equation ... k j d, ... k i d -> ... i j d in einsum notation.
  • Used to understand pairwise interactions.
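The two mix equations above can be tried out directly with torch.einsum. This self-contained snippet (random tensors, made-up sizes) shows that the outgoing update aggregates edges ik and jk over k, while the incoming update aggregates ki and kj.

import torch

n, d = 8, 4
left = torch.randn(1, n, n, d)    # stand-in for the 'left' projected edges
right = torch.randn(1, n, n, d)   # stand-in for the 'right' projected edges

# outgoing: edge ij is built from edges ik and jk, summed over k
out_outgoing = torch.einsum('... i k d, ... j k d -> ... i j d', left, right)

# incoming: edge ij is built from edges ki and kj, summed over k
out_incoming = torch.einsum('... k j d, ... k i d -> ... i j d', left, right)

print(out_outgoing.shape, out_incoming.shape)   # both (1, n, n, d)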

Triangle Self-Attention Around Starting Node:

  • Focuses on how node i is related to j and k: attention is over the outgoing edges from i.
  • Shows how updates can be propagated through the graph using different edges.

Triangle Self-Attention Around Ending Node:

  • Focuses on how node j is related to i and k: attention is over the edges ending at j (incoming edges).
  • Shows how updates can be propagated through the graph using different edges.
  • Triangle Multiplicative Update and Triangle Self-Attention
    • we define a non-attention update operation ‘triangle multiplicative update’ that uses two edges to update the missing third edge
      • By using two edges of a triangle to update the third edge, this update operation establishes geometric consistency
    • we add an extra logit bias to axial attention to include the ‘missing edge’ of the triangle
      • This means that when considering a triangle formed by nodes i, j, k, the logit bias helps to account for the edge not directly considered in a particular attention step, ensuring that all pairwise relationships within the triangle are considered.
  • Transition
    • It is just a feed-forward (FF) block of linear layers; see the sketch below.
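As a sketch of what "just a feed-forward block" means here: the transition expands the channel dimension, applies a nonlinearity, and projects back. The expansion factor and activation below are illustrative assumptions, not taken verbatim from the paper.

import torch
import torch.nn as nn

class TransitionFF(nn.Module):
    """Minimal sketch of a pointwise feed-forward transition block."""
    def __init__(self, dim, expansion = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * expansion),  # expand channels
            nn.ReLU(),
            nn.Linear(dim * expansion, dim),  # project back
        )

    def forward(self, x):
        return self.net(x)

print(TransitionFF(dim = 64)(torch.randn(1, 16, 16, 64)).shape)   # torch.Size([1, 16, 16, 64])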

Interestingly, the paper only mentions: "The triangle multiplicative update was developed originally as a more symmetric and cheaper replacement for the attention, and networks that use only the attention or multiplicative update are both able to produce high-accuracy structures. However, the combination of the two updates is more accurate."

class TriangleMultiplication(Module):

    @typecheck
    def __init__(
        self,
        *,
        dim,
        dim_hidden = None,
        mix: Literal["incoming", "outgoing"] = 'incoming',
        dropout = 0.,
        dropout_type: Literal['row', 'col'] | None = None
    ):
        super().__init__()

        dim_hidden = default(dim_hidden, dim) # Use dim_hidden if provided, otherwise use dim
        self.norm = nn.LayerNorm(dim) # Normalize input tensor

        self.left_right_proj = nn.Sequential(
            LinearNoBias(dim, dim_hidden * 4), # Linear projection to higher dimension
            nn.GLU(dim = -1) # Gated Linear Unit activation; output is split into left and right components
        )

        self.left_right_gate = LinearNoBias(dim, dim_hidden * 2) # Linear projection for gating values
        self.out_gate = LinearNoBias(dim, dim_hidden) # Linear projection for output gate

        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d' # Einsum equation for outgoing edges
            # For each batch (...), edge ij is updated from edges ik and jk by summing over the k dimension,
            # producing an output tensor indexed by i and j.
        elif mix == 'incoming':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d' # Einsum equation for incoming edges
            # For each batch (...), edge ij is updated from edges ki and kj by summing over the k dimension,
            # again producing an output tensor indexed by i and j.

        self.to_out_norm = nn.LayerNorm(dim_hidden) # Normalize mixed output tensor

        self.to_out = Sequential(
            LinearNoBias(dim_hidden, dim), # Project back to original dimension
            Dropout(dropout, dropout_type = dropout_type) # Apply dropout for regularization
        )

    @typecheck
    def forward(
        self,
        x: Float['b n n d'],
        mask: Bool['b n'] | None = None
    ) -> Float['b n n d']:

        if exists(mask):
            mask = einx.logical_and('b i, b j -> b i j 1', mask, mask) # Expand per-residue mask to pairwise shape

        x = self.norm(x) # Normalize input tensor

        left, right = self.left_right_proj(x).chunk(2, dim = -1) # Project to higher dimension and split into left and right components

        if exists(mask):
            left = left * mask # Apply mask to left component
            right = right * mask # Apply mask to right component

        out = einsum(left, right, self.mix_einsum_eq) # Mix left and right components using einsum

        out = self.to_out_norm(out) # Normalize mixed output tensor

        out_gate = self.out_gate(x).sigmoid() # Compute output gate
        out = out * out_gate # Apply output gate to mixed output tensor

        return self.to_out(out) # Project back to original dimension and apply dropout
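A quick usage sketch of the block above on a random pair representation, assuming the TriangleMultiplication class and its helpers from the alphafold3-pytorch code are in scope; the dimensions here are arbitrary.

import torch

tri_out = TriangleMultiplication(dim = 64, mix = 'outgoing')
tri_in = TriangleMultiplication(dim = 64, mix = 'incoming')

pairwise = torch.randn(1, 16, 16, 64)   # (batch, n, n, dim) pair representation
mask = torch.ones(1, 16).bool()         # per-residue mask

pairwise = pairwise + tri_out(pairwise, mask = mask)   # residual update with outgoing edges
pairwise = pairwise + tri_in(pairwise, mask = mask)    # residual update with incoming edges
print(pairwise.shape)                                  # torch.Size([1, 16, 16, 64])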
class TriangleAttention(Module):
    def __init__(
        self,
        *,
        dim,
        heads,
        node_type: Literal['starting', 'ending'],
        dropout = 0.,
        dropout_type: Literal['row', 'col'] | None = None,
        **attn_kwargs
    ):
        super().__init__()
        self.need_transpose = node_type == 'ending' # Determine if transposition is needed

        self.attn = Attention(dim = dim, heads = heads, **attn_kwargs) # Initialize attention mechanism

        self.dropout = Dropout(dropout, dropout_type = dropout_type) # Initialize dropout layer

        self.to_attn_bias = nn.Sequential(
            LinearNoBias(dim, heads), # Linear layer projecting to attention heads
            Rearrange('... i j h -> ... h i j') # Rearrange dimensions for attention bias
        )

    @typecheck
    def forward(
        self,
        pairwise_repr: Float['b n n d'],
        mask: Bool['b n'] | None = None,
        **kwargs
    ) -> Float['b n n d']:

        if self.need_transpose:
            pairwise_repr = rearrange(pairwise_repr, 'b i j d -> b j i d') # Transpose if needed

        attn_bias = self.to_attn_bias(pairwise_repr) # Calculate attention bias

        batch_repeat = pairwise_repr.shape[1] # Determine batch repeat size
        attn_bias = repeat(attn_bias, 'b ... -> (b repeat) ...', repeat = batch_repeat) # Repeat attention bias

        if exists(mask):
            mask = repeat(mask, 'b ... -> (b repeat) ...', repeat = batch_repeat) # Repeat mask if it exists

        pairwise_repr, packed_shape = pack_one(pairwise_repr, '* n d') # Pack tensor for processing

        out = self.attn(
            pairwise_repr,
            mask = mask,
            attn_bias = attn_bias,
            **kwargs
        ) # Apply attention

        out = unpack_one(out, packed_shape, '* n d') # Unpack tensor to original shape

        if self.need_transpose:
            out = rearrange(out, 'b j i d -> b i j d') # Transpose back if needed

        return self.dropout(out) # Apply dropout and return output

Structure Module

img

IPA (Invariant Point Attention)

A geometry-aware attention operation (i.e. IPA) is used to update single representation without changing the 3D positions, then an equivariant update operation is performed on the residue gas using the updated activations.

  • Augments each of the usual attention queries, keys and values with 3D points that are produced in the local frame of each residue such that the final value is invariant to global rotations and translations.
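The key property can be checked numerically: if per-residue points are produced in local frames and the attention logits depend only on distances between the globally placed points, then applying one global rotation and translation to every frame leaves those logits unchanged. The snippet below is a toy demonstration of that invariance, not the actual IPA module.

import torch

def random_rotation():
    # random proper rotation matrix via QR decomposition
    q, r = torch.linalg.qr(torch.randn(3, 3))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

n = 8
R = torch.stack([random_rotation() for _ in range(n)])   # per-residue frame rotations (n, 3, 3)
t = torch.randn(n, 3)                                    # per-residue frame translations
q_local = torch.randn(n, 3)                              # query points in each residue's local frame

def pairwise_point_logits(R, t):
    x = torch.einsum('n i j, n j -> n i', R, q_local) + t   # place points in the global frame
    return -torch.cdist(x, x)                                # distance-based attention logits

logits = pairwise_point_logits(R, t)

# apply one global rotation + translation to every frame
Rg, tg = random_rotation(), torch.randn(3)
R2 = torch.einsum('i j, n j k -> n i k', Rg, R)
t2 = torch.einsum('i j, n j -> n i', Rg, t) + tg

logits2 = pairwise_point_logits(R2, t2)
assert torch.allclose(logits, logits2, atol = 1e-5)      # invariant to the global transform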

Training of AlphaFold 2

AlphaFold 2 training is split into two phases:

  • Initial training phase
  • More computationally intensive fine-tuning phase
    • The size of protein fragments used for training is increased to 384 residues
    • An additional loss function penalizes structural violations

OpenFold (2024)

OpenFold: retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization

  • A group of researchers reimplemented AlphaFold 2 to look for more insights.
    • They introduced OpenProteinSet, a training dataset sufficient to reproduce AlphaFold 2's performance

img

OpenFold generalization capacity on elided training sets.

  • (A) Validation set lDDT-Cα as a function of training step for models trained on elided training sets (10k random split repeated 3x).
  • (B) Same as (A) but for CATH-stratified dataset elisions. Validation sets vary across stratifications and are not directly comparable.
  • (C) Experimental structures (orange) and mainly alpha-trained (yellow) and mainly beta-trained (red) predictions of largely helical Lsi1 (top) and beta sheet-heavy TMED1 (bottom).

Insights: Efficient Training

It is possible to train with much less compute

OpenFold achieves

  • ~90% of its final accuracy in just 1,500 GPU hours (~3% of total training time)
  • ~95% of its final accuracy in just 2,500 GPU hours
  • Total training time: ~50,000 GPU hours

A small dataset can achieve full-model performance

  • About 7.6% of all training data suffices to reach the same initial lDDT-Cα (model accuracy) as a model trained on the full training set.

AlphaFold 2 can generalize well

  • Filtering out structural categories (CATH) from the training set still yields good performance

Insights: Components Analysis

The second training phase only resolves chemical constraints

  • The second training phase (fine-tuning) has only a modest effect on overall structure
  • The primary utility of fine-tuning appears to be resolving violations of known chemical constraints

Templates are not what you need

  • Templates have a minimal effect except when MSAs are shallow or entirely absent.

Improvements on OpenFold

Improve Training Stability and Speed up model convergence

  • FP16 low-precision training
  • Uses a different clamping protocol in the primary structural loss, FAPE

AlphaFold 3 (2024)

Accurate structure prediction of biomolecular interactions with AlphaFold 3

img

Modalities: Sequences, Ligands, Covalent bonds (Raw atoms + coarse abstract token)

  • Inputs: Protein Data Bank (PDB) dataset (> 1TB)
  • Input Embedder: extract representations (pair, single)
  • PairFormer (Recycling): to view the prediction of protein structures as a graph inference problem in 3D space in which the edges of the graph are defined by residues in proximity.
  • Diffusion Decoder: builds out the 3D coordinates

PDB Dataset Curation

  • Extract mmCIF files
    • a format that describes the 3D structures of biomolecules
  • Filter mmCIF files
    • filter out certain bioassemblies
  • Select the closest chains in large bioassemblies
@typecheck
@timeout_decorator.timeout(PROCESS_STRUCTURE_MAX_SECONDS, use_signals=False)
def process_structure_with_timeout(filepath: str, output_dir: str):
    """
    Given an input mmCIF file, create a new processed mmCIF file
    using AlphaFold 3's PDB dataset filtering criteria under a
    timeout constraint.
    """
    # Section 2.5.4 of the AlphaFold 3 supplement
    structure_id = os.path.splitext(os.path.basename(filepath))[0]
    output_file_dir = os.path.join(output_dir, structure_id[1:3])
    output_filepath = os.path.join(output_file_dir, f"{structure_id}.cif")
    os.makedirs(output_file_dir, exist_ok=True)

    # Filtering of targets
    structure = parse_structure(filepath)
    structure = filter_target(structure)
    if exists(structure):
        # Filtering of bioassemblies
        structure = remove_hydrogens(structure)
        structure = remove_all_unknown_residue_chains(structure, STANDARD_RESIDUES)
        structure = remove_clashing_chains(structure)
        structure = remove_excluded_ligands(structure, LIGAND_EXCLUSION_LIST)
        structure = remove_non_ccd_atoms(structure, CCD_READER_RESULTS)
        structure = remove_leaving_atoms(structure, CCD_READER_RESULTS)
        structure = filter_large_ca_distances(structure)
        structure = select_closest_chains(
            structure, PROTEIN_RESIDUE_CENTER_ATOMS, NUCLEIC_ACID_RESIDUE_CENTER_ATOMS
        )
        structure = remove_crystallization_aids(structure, CRYSTALLOGRAPHY_METHODS)
        if list(structure.get_chains()):
            # Save processed structure
            write_structure(structure, output_filepath)
            print(f"Finished processing structure: {structure.id}")

Input Embedder

  • An MLP as the first projection
  • Atom Transformer (DiT-style)
  • Another MLP as the last projection (see the sketch below)
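A minimal sketch of that MLP → Atom Transformer → MLP shape, using a generic nn.TransformerEncoder as a stand-in for the actual Atom Transformer; all dimensions and layer counts here are placeholders, not the AlphaFold 3 values.

import torch
import torch.nn as nn

class InputEmbedderSketch(nn.Module):
    def __init__(self, dim_atom_in = 77, dim = 128, depth = 3, heads = 4):
        super().__init__()
        self.proj_in = nn.Sequential(                   # first MLP projection
            nn.Linear(dim_atom_in, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        encoder_layer = nn.TransformerEncoderLayer(     # stand-in for the Atom Transformer (DiT-style)
            d_model = dim, nhead = heads, batch_first = True
        )
        self.atom_transformer = nn.TransformerEncoder(encoder_layer, num_layers = depth)
        self.proj_out = nn.Sequential(                  # last MLP projection
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, atom_feats):                      # (batch, num_atoms, dim_atom_in)
        x = self.proj_in(atom_feats)
        x = self.atom_transformer(x)
        return self.proj_out(x)

emb = InputEmbedderSketch()
print(emb(torch.randn(2, 40, 77)).shape)                # torch.Size([2, 40, 128])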

Template Module (TemplateEmbedder)

  • Adds the pair representation and the template embedding together
  • 2 stacks of TriangleMultiplication + TriangleAttention + Projection (same as AlphaFold 2)

MSA Module

  • 4 Stacks of TriangleMultiplication+TriangleAttention+Projection (Same as AlphaFold2)
  • and some extra process to modify pairwise representation using the MSA mask

PairFormer

has 48 PairFormerBlocks

  • Triangle Update: Same as AlphaFold2
  • Triangle Self-Attention: Same as AlphaFold2
  • Transition: LinearNoBias(dim, dim*4*2), SwiGLU(), LinearNoBias(dim*4, dim)
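The transition spec above maps directly to code. Here is a minimal sketch of that SwiGLU transition, with LinearNoBias written as nn.Linear(..., bias=False); the pre-LayerNorm is an assumption for the sketch.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUTransition(nn.Module):
    def __init__(self, dim, expansion = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)                                      # assumption: pre-norm, as in the other blocks
        self.proj_in = nn.Linear(dim, dim * expansion * 2, bias = False)   # LinearNoBias(dim, dim*4*2)
        self.proj_out = nn.Linear(dim * expansion, dim, bias = False)      # LinearNoBias(dim*4, dim)

    def forward(self, x):
        x = self.norm(x)
        gates, values = self.proj_in(x).chunk(2, dim = -1)
        return self.proj_out(F.silu(gates) * values)                       # SwiGLU: SiLU-gated linear unit

print(SwiGLUTransition(dim = 64)(torch.randn(1, 16, 64)).shape)            # torch.Size([1, 16, 64])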

img

Diffusion

The diffusion module operates directly on raw atom coordinates, and on a coarse abstract token representation, without rotational frames or any equivariant processing.

img

The diffusion module. Input: coarse arrays depict per-token representations (green, inputs; blue, pair; red, single). Fine arrays depict per-atom representations. The coloured balls represent physical atom coordinates. Cond., conditioning; rand. rot. trans., random rotation and translation; seq., sequence.

  • A Diffusion Transformer conditioned on the single representation and pair representation
  • Elucidated Diffusion Model (EDM), adapted for diffusing atom positions
  • At each step, a random point cloud of atoms is denoised, biased by per-token and per-atom information; see the sketch below.
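As a rough sketch of what "EDM adapted for atom position diffusing" means at sample time: start from pure noise at a large noise level and repeatedly apply the denoiser while stepping down a decreasing noise schedule. The Euler-style loop below is generic EDM-style pseudocode with a placeholder denoiser, not the actual AlphaFold 3 sampler (which adds noise augmentation and other details).

import torch

def sample_atom_positions(denoise_fn, num_atoms, steps = 16, sigma_max = 80.0, sigma_min = 0.05, rho = 7.0):
    """Euler sampler over an EDM-style noise schedule. denoise_fn(x, sigma) should return
    an estimate of the clean atom coordinates given noisy coordinates x at noise level sigma."""
    # Karras-style schedule from sigma_max down to sigma_min
    ramp = torch.linspace(0, 1, steps)
    sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, torch.zeros(1)])

    x = torch.randn(num_atoms, 3) * sigmas[0]        # start from a random point cloud
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoise_fn(x, sigma)              # predicted clean coordinates
        d = (x - denoised) / sigma                   # probability-flow ODE derivative
        x = x + d * (sigma_next - sigma)             # Euler step toward lower noise
    return x

# toy denoiser that just predicts the origin (placeholder for the trained diffusion module)
coords = sample_atom_positions(lambda x, sigma: x * 0.0, num_atoms = 10)
print(coords.shape)   # torch.Size([10, 3])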

Why Diffusion Prior?

3D generation is hard to do with an autoregressive (AR) approach.

class DiffusionTransformer(Module):

    def __init__(
        self,
        *,
        depth,
        heads,
        dim = 384,
        dim_single_cond = None,
        dim_pairwise = 128,
        attn_window_size = None,
        attn_pair_bias_kwargs: dict = dict(),
        attn_num_memory_kv = False,
        num_register_tokens = 0,
        serial = False,
        use_linear_attn = False,
        linear_attn_kwargs = dict(
            heads = 8,
            dim_head = 16
        )
    ):
        super().__init__()
        self.attn_window_size = attn_window_size

        dim_single_cond = default(dim_single_cond, dim)

        layers = ModuleList([])

        for _ in range(depth):

            linear_attn = None

            if use_linear_attn:
                linear_attn = TaylorSeriesLinearAttn(
                    dim = dim,
                    prenorm = True,
                    gate_value_heads = True,
                    **linear_attn_kwargs
                )

            pair_bias_attn = AttentionPairBias(
                dim = dim,
                dim_pairwise = dim_pairwise,
                heads = heads,
                window_size = attn_window_size,
                num_memory_kv = attn_num_memory_kv,
                **attn_pair_bias_kwargs
            )

            transition = Transition(
                dim = dim
            )

            conditionable_pair_bias = ConditionWrapper(
                pair_bias_attn,
                dim = dim,
                dim_cond = dim_single_cond
            )

            conditionable_transition = ConditionWrapper(
                transition,
                dim = dim,
                dim_cond = dim_single_cond
            )

            layers.append(ModuleList([
                linear_attn,
                conditionable_pair_bias,
                conditionable_transition
            ]))

        self.layers = layers

        self.serial = serial

        self.has_registers = num_register_tokens > 0
        self.num_registers = num_register_tokens

        if self.has_registers:
            assert not exists(attn_window_size), 'register tokens disabled for windowed attention'

            self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))

    @typecheck
    def forward(
        self,
        noised_repr: Float['b n d'],
        *,
        single_repr: Float['b n ds'],
        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
        mask: Bool['b n'] | None = None
    ):
        w = self.attn_window_size
        has_windows = exists(w)

        serial = self.serial

        # handle windowing

        pairwise_is_windowed = pairwise_repr.ndim == 5

        if has_windows and not pairwise_is_windowed:
            pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size = w)

        # register tokens

        if self.has_registers:
            num_registers = self.num_registers
            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')

            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)

            if exists(mask):
                mask = F.pad(mask, (num_registers, 0), value = True)

        # main transformer

        for linear_attn, attn, transition in self.layers:

            if exists(linear_attn):
                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr

            attn_out = attn(
                noised_repr,
                cond = single_repr,
                pairwise_repr = pairwise_repr,
                mask = mask
            )

            if serial:
                noised_repr = attn_out + noised_repr

            ff_out = transition(
                noised_repr,
                cond = single_repr
            )

            if not serial:
                ff_out = ff_out + attn_out

            noised_repr = noised_repr + ff_out

        # splice out registers

        if self.has_registers:
            _, noised_repr = unpack(noised_repr, registers_ps, 'b * d')

        return noised_repr

Training Detail of AlphaFold 3

img

Mini-Rollout

  • The mini rollout helps refine the structure prediction by iteratively improving the output based on the current model’s state.
  • stabilizes the predictions, ensuring they converge towards a more accurate structure.
  • The mini rollout involves a sequence of 20 iterations (as indicated in the diagram) where the model repeatedly refines its prediction.
  • During each iteration, the model uses the output from the previous iteration as the input for the next, gradually reducing noise and correcting errors.
  • After 20 iterations, the refined prediction is used to permute the ground truth chains and ligands, providing a more accurate basis for the subsequent training steps.
  • Note that the mini rollout does not directly contribute to the loss calculation.

Confidence module

  • 1 Pairformer Stack
  • 1 Linear Projection for each kind of logits (PAE, PDE, pLDDT, Resolved)

Total Loss:

loss = (
    distogram_loss * self.loss_distogram_weight +
    diffusion_loss * self.loss_diffusion_weight +
    confidence_loss * self.loss_confidence_weight
)

Loss functions

Diffusion Loss

  • Definition: Training Loss of diffusion module
  • Role in Training: The loss function used to train the model to reverse the diffusion process.
    • MSE + Bond Loss + Smooth LDDT Loss

Distogram Loss

  • Definition: A distogram is a histogram of distances between pairs of residues in a protein.
  • Role in Training: This loss helps the model learn the correct distance distributions between residue pairs.
    • Cross Entropy
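To make the distogram loss concrete, here is a minimal sketch: bin the true pairwise distances and apply cross entropy against predicted per-pair bin logits. The bin count, distance range, and tensor sizes below are placeholders for illustration.

import torch
import torch.nn.functional as F

n, num_bins = 16, 37                              # placeholder sequence length and bin count
coords = torch.randn(1, n, 3) * 10                # mock ground-truth residue coordinates
pred_logits = torch.randn(1, n, n, num_bins)      # mock predicted distogram logits

# bin true distances into num_bins buckets between 2 and 22 Angstrom (placeholder range)
dist = torch.cdist(coords, coords)
boundaries = torch.linspace(2.0, 22.0, num_bins - 1)
dist_labels = torch.bucketize(dist, boundaries)   # (1, n, n) integer bin per residue pair

distogram_loss = F.cross_entropy(
    pred_logits.flatten(0, 2),                    # (n*n, num_bins)
    dist_labels.flatten()                         # (n*n,)
)
print(distogram_loss)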

Confidence Loss

  • PAE + PDE + pLDDT + Resolved Loss

PAE (Predicted Aligned Error):

  • Definition: PAE measures the expected positional error in the predicted structure of a protein when aligned to a reference structure.
  • Role in Training: The pae_labels are ground truth values that supervise the model’s prediction of positional errors in the protein structure alignment.
    • Cross Entropy

PDE (Predicted Distance Error):

  • Definition: PDE refers to the predicted error in inter-residue distances within the protein structure.
  • Role in Training: The pde_labels are ground truth values that guide the model in predicting errors in the distance matrix of the protein’s predicted structure compared to the true structure.
    • Cross Entropy

pLDDT (Predicted Local Distance Difference Test):

  • Definition: pLDDT is a per-residue confidence score assigned by AlphaFold to its predictions, indicating the reliability of each residue’s predicted position. Each of these residues will have a predicted position in the 3D structure.
  • Role in Training: The plddt_labels are ground truth confidence scores used to train the model’s prediction of how accurate each residue’s position is.
    • Cross Entropy

Resolved Loss:

  • Definition: Resolved labels are binary indicators that show whether a residue is well-resolved or confidently predicted in the protein structure.
  • Role in Training: The resolved_labels serve as ground truth binary indicators to help the model learn to predict which residues are well-resolved in the structure.
    • Cross Entropy

Usage Example

  • Batch size = 2
import torch
from alphafold3_pytorch import Alphafold3
# pip install alphafold3-pytorch
# https://github.com/lucidrains/alphafold3-pytorch


alphafold3 = Alphafold3(
    dim_atom_inputs = 77,
    dim_template_feats = 44
)

# mock inputs

seq_len = 16
molecule_atom_lens = torch.randint(1, 3, (2, seq_len)) # shape: (2, seq_len)
atom_seq_len = molecule_atom_lens.sum(dim=-1).amax() # scalar

atom_inputs = torch.randn(2, atom_seq_len, 77) # shape: (2, atom_seq_len, 77). 77 feature dimensions
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5) # shape: (2, atom_seq_len, atom_seq_len, 5). 5 features

additional_molecule_feats = torch.randn(2, seq_len, 9) # shape: (2, seq_len, 9). 9 features
molecule_ids = torch.randint(0, 32, (2, seq_len)) # shape: (2, seq_len)

template_feats = torch.randn(2, 2, seq_len, seq_len, 44) # shape: (2, 2, seq_len, seq_len, 44). 44 feature dimensions
template_mask = torch.ones((2, 2)).bool() # shape: (2, 2)

msa = torch.randn(2, 7, seq_len, 64) # shape: (2, 7, seq_len, 64). 7 aligned sequences, 64 feature dimensions
msa_mask = torch.ones((2, 7)).bool() # shape: (2, 7)

# required for training, but omitted on inference

atom_pos = torch.randn(2, atom_seq_len, 3) # shape: (2, atom_seq_len, 3). 3D coordinates for each atom
molecule_atom_indices = molecule_atom_lens - 1 # shape: (2, seq_len). inferred from molecule_atom_lens

distance_labels = torch.randint(0, 37, (2, seq_len, seq_len)) # shape: (2, seq_len, seq_len)
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len)) # shape: (2, seq_len, seq_len)
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len)) # shape: (2, seq_len, seq_len)
plddt_labels = torch.randint(0, 50, (2, seq_len)) # shape: (2, seq_len)
resolved_labels = torch.randint(0, 2, (2, seq_len)) # shape: (2, seq_len)

# train

loss = alphafold3(
    num_recycling_steps = 2,
    atom_inputs = atom_inputs, # shape: (2, atom_seq_len, 77)
    atompair_inputs = atompair_inputs, # shape: (2, atom_seq_len, atom_seq_len, 5)
    molecule_ids = molecule_ids, # shape: (2, seq_len)
    molecule_atom_lens = molecule_atom_lens, # shape: (2, seq_len)
    additional_molecule_feats = additional_molecule_feats, # shape: (2, seq_len, 9)
    msa = msa, # shape: (2, 7, seq_len, 64)
    msa_mask = msa_mask, # shape: (2, 7)
    templates = template_feats, # shape: (2, 2, seq_len, seq_len, 44)
    template_mask = template_mask, # shape: (2, 2)
    atom_pos = atom_pos, # shape: (2, atom_seq_len, 3)
    molecule_atom_indices = molecule_atom_indices, # shape: (2, seq_len)
    distance_labels = distance_labels, # shape: (2, seq_len, seq_len)
    pae_labels = pae_labels, # shape: (2, seq_len, seq_len)
    pde_labels = pde_labels, # shape: (2, seq_len, seq_len)
    plddt_labels = plddt_labels, # shape: (2, seq_len)
    resolved_labels = resolved_labels # shape: (2, seq_len)
)

loss.backward()

# after much training ...

sampled_atom_pos = alphafold3(
    num_recycling_steps = 4,
    num_sample_steps = 16,
    atom_inputs = atom_inputs,
    atompair_inputs = atompair_inputs,
    molecule_ids = molecule_ids,
    molecule_atom_lens = molecule_atom_lens,
    additional_molecule_feats = additional_molecule_feats,
    msa = msa,
    msa_mask = msa_mask,
    templates = template_feats,
    template_mask = template_mask
)

sampled_atom_pos.shape # (2, <atom_seqlen>, 3). 3D coordinates of sampled atom positions for each molecule in the batch

Extra Info

OpenFold: https://github.com/aqlaboratory/openfold/tree/main

Implementation of Alphafold 3 in Pytorch: https://github.com/lucidrains/alphafold3-pytorch