Paper Review - Axial Transformer and MSA Transformer
Revisit Self-Attention
```python
class SelfAttention(nn.Module):
```
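A possible body for the `SelfAttention` module above, sketched as standard single-head scaled dot-product attention; the projection layout and dimension names are assumptions for illustration, not the post's original implementation:

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)  # joint Q, K, V projection
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)  # (batch, L, L)
        return self.to_out(attn @ v)  # cost grows as O(L^2) in the sequence length L
```

The quadratic cost in sequence length is what both papers below work around by restricting attention to one axis at a time.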
Axial Transformer
Axial Attention in Multidimensional Transformers (2019)
The Axial Transformer is a transformer architecture that enables efficient self-attention for high-dimensional data like images and videos. It employs axial/row-column attention to model dependencies along the different axes of the input tensor, making it suitable for various computer vision tasks involving 2D or 3D data.
The Axial Transformer model for 2-dimensional tensors. Before sampling a channel we encode all previous channels and frames with 8 blocks of unmasked row and unmasked column attention (left). Then, for each row, we apply 4 blocks of unmasked row and masked column attention to integrate the previously sampled rows for the active channels into our encoded representation (middle). Finally, we shift the encoded representation up to make sure the conditioning information satisfies causality, and we run the inner decoder consisting of 4 blocks of masked row attention to sample a new row in the image (right).
Key ideas of the paper:
- Decomposition of Multi-Dimensional Attention
  - Rather than applying attention to a flattened string of tensor elements, the model applies attention along a single axis of the tensor without flattening, which the authors call “axial attention.”
  - This decomposes the multi-dimensional attention problem into several lower-dimensional attention operations.
- Reduced Computational Complexity
  - Since the length of any single axis (that is, the height or width of an image) is typically much smaller than the total number of elements, an axial attention operation enjoys a significant saving in computation and memory over standard self-attention.
  - By attending to rows and columns sequentially, axial attention can still aggregate information across the entire spatial extent. Although each attention operation is limited to one dimension, the sequential application lets the model capture interactions across all dimensions (see the code sketch after this list):
    - First, attention along rows captures dependencies within each row.
    - Then, attention along columns captures dependencies within each column. This two-step process ensures that every position in the 2D space can attend to every other position, albeit indirectly.
- High Flexibility
  - Axial attention’s flexibility with high-dimensional data makes it suitable for applications beyond images, such as video and 3D data, where full attention mechanisms would be even more impractical.
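The row/column decomposition can be implemented by folding the other spatial axis into the batch dimension. Below is a minimal sketch; the class name, the (batch, height, width, channels) layout, and the use of `nn.MultiheadAttention` are illustrative assumptions, not the paper's reference code:

```python
import torch
import torch.nn as nn

class AxialAttention2d(nn.Module):
    """Illustrative sketch: self-attention applied along a single axis of a (B, H, W, C) tensor."""

    def __init__(self, dim, heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, axis):
        # x: (batch, height, width, channels)
        # axis=2 attends along rows (across width), axis=1 attends along columns (across height).
        b, h, w, c = x.shape
        if axis == 2:
            seq = x.reshape(b * h, w, c)                      # each row becomes one sequence
        else:
            seq = x.permute(0, 2, 1, 3).reshape(b * w, h, c)  # each column becomes one sequence
        out, _ = self.attn(seq, seq, seq)                     # standard self-attention along that axis
        if axis == 2:
            return out.reshape(b, h, w, c)
        return out.reshape(b, w, h, c).permute(0, 2, 1, 3)

# Row attention followed by column attention gives every position an
# (indirect) path to every other position in the 2D grid.
x = torch.randn(2, 32, 32, 64)
attn = AxialAttention2d(dim=64)
y = attn(attn(x, axis=2), axis=1)   # shape preserved: (2, 32, 32, 64)
```

For a square H x W image with N = H·W pixels, each pass attends over sequences of length sqrt(N), so a full row-plus-column sweep costs O(N·sqrt(N)) rather than the O(N^2) of full self-attention.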
How the Axial Transformer compares with other self-attention-based autoregressive models (table from the paper):

Model | Full receptive field | Attention faster than O(N^2) | Needs no custom kernels | Semi-parallel context aggregation
---|---|---|---|---
Transformer (Vaswani et al., 2017) | yes | no | yes | no |
Image Transformer (Parmar et al., 2018) | no | yes | yes | no |
Block Transformer (Weissenborn et al., 2019) | no | yes | yes | no |
Strided Sparse Transformer (Child et al., 2019) | yes | yes | no | no |
Axial Transformer | yes | yes | yes | yes |
```python
# https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L153
```
Example Usage:
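A usage sketch for the lucidrains `axial-attention` package linked above; the constructor arguments follow its README, so treat the exact argument names and defaults as assumptions if the package version differs:

```python
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)   # (batch, channels, height, width)

attn = AxialAttention(
    dim=3,              # embedding dimension (the channel axis here)
    dim_index=1,        # which axis of the input holds the embedding dimension
    heads=1,            # number of attention heads
    num_dimensions=2,   # number of axial dimensions (2 for images, 3 for video)
)

out = attn(img)         # shape preserved: (1, 3, 256, 256)
```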
MSA Transformer
Multiple Sequence Alignment Transformer (2021)
The MSA Transformer is a transformer model designed to learn representations of evolutionary-related protein sequences from multiple sequence alignments (MSAs). It uses a masked language modeling objective to predict masked amino acids, capturing biologically relevant information useful for protein structure prediction tasks like contact prediction.
Left: Sparsity structure of the attention. By constraining attention to operate over rows and columns, the computational cost is reduced from O(M^2 L^2) to O(LM^2) + O(ML^2), where M is the number of rows and L the number of columns in the MSA. Middle: Untied row attention uses different attention maps for each sequence in the MSA. Tied row attention uses a single attention map for all sequences in the MSA, thereby constraining the contact structure. Ablation studies consider the use of both tied and untied attention; the final model uses tied attention. Right: A single MSA Transformer block. The depicted architecture is from the final model; some ablations alter the ordering of row and column attention.
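To make tied row attention concrete, here is a minimal single-head sketch in which one L x L attention map is shared by all M rows; the module name, tensor layout, and the square-root normalization over rows are illustrative assumptions, not the paper's reference code:

```python
import torch
import torch.nn as nn

class TiedRowSelfAttention(nn.Module):
    """Sketch: one attention map over columns, shared (tied) across all rows of the MSA."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        # x: (batch, M rows, L cols, dim) -- rows are aligned sequences, cols are positions
        b, m, l, d = x.shape
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Sum per-row attention logits over the rows so all sequences share one (L x L) map;
        # the extra 1/sqrt(M) factor is the square-root normalization described in the paper.
        logits = torch.einsum('bmid,bmjd->bij', q, k) * self.scale / (m ** 0.5)
        attn = logits.softmax(dim=-1)                    # (batch, L, L)
        return torch.einsum('bij,bmjd->bmid', attn, v)   # the same map is applied to every row
```

Because the attention map is shared across rows, it can be read directly as an L x L matrix over alignment positions, which is what makes the row-attention heads useful for contact prediction.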
Key ideas of the paper:
- Alternates attention over rows and columns of the 2D state.
- This sparsity pattern in the attention over the MSA brings the attention cost to O(LM^2) for the column attention and O(ML^2) for the row attention.
- Rather than applying a feedforward layer after each row or column attention, the block applies row and column attention followed by a single feedforward layer (a sketch of this ordering follows the code reference below).
```python
# https://github.com/rmrao/msa-transformer/blob/main/modules.py#L191
```
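A minimal sketch of one block's ordering, i.e. row attention, then column attention, then a single feedforward layer. Pre-layer-norm residual wiring and untied row attention are simplifying assumptions here (the final model uses tied row attention, sketched earlier); this is not the linked implementation:

```python
import torch
import torch.nn as nn

class MSATransformerBlock(nn.Module):
    """Sketch of one block: row attention -> column attention -> a single feedforward layer."""

    def __init__(self, dim, heads=8):
        super().__init__()
        self.row_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.col_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        # x: (batch, M rows, L cols, dim) -- an embedded MSA
        b, m, l, d = x.shape
        # Row attention: attend across the L positions within each sequence.
        h = self.norm1(x).reshape(b * m, l, d)
        x = x + self.row_attn(h, h, h)[0].reshape(b, m, l, d)
        # Column attention: attend across the M sequences at each position.
        h = self.norm2(x).permute(0, 2, 1, 3).reshape(b * l, m, d)
        x = x + self.col_attn(h, h, h)[0].reshape(b, l, m, d).permute(0, 2, 1, 3)
        # A single feedforward layer after both attention passes.
        return x + self.ffn(self.norm3(x))
```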