Revisiting Self-Attention

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads = None):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim)

    def forward(self, x, kv = None):
        # self-attention when kv is None, cross-attention otherwise
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))

        b, t, d, h, e = *q.shape, self.heads, self.dim_heads

        # split the hidden dimension into heads and fold the heads into the batch dimension
        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))

        # scaled dot-product attention
        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)
        out = torch.einsum('bij,bje->bie', dots, v)

        # unfold the heads and project back to the model dimension
        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        return out
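
As a quick, hypothetical shape check (not part of the reference code), the block above can be exercised as follows; the kv argument lets the same module act as cross-attention:

x = torch.randn(2, 100, 64)            # (batch, sequence length, dim)
attn = SelfAttention(dim=64, heads=8)
out = attn(x)                          # (2, 100, 64)

context = torch.randn(2, 50, 64)       # keys/values from another source
out = attn(x, kv=context)              # (2, 100, 64)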

Axial Transformer

Axial Attention in Multidimensional Transformers (2019)

The Axial Transformer is a transformer architecture that enables efficient self-attention for high-dimensional data like images and videos. It employs axial/row-column attention to model dependencies along the different axes of the input tensor, making it suitable for various computer vision tasks involving 2D or 3D data.

Figure (from the paper): The Axial Transformer model for 2-dimensional tensors. Before sampling a channel we encode all previous channels and frames with 8 blocks of unmasked row and unmasked column attention (left). Then, for each row, we apply 4 blocks of unmasked row and masked column attention to integrate the previously sampled rows for the active channels into our encoded representation (middle). Finally, we shift the encoded representation up to make sure the conditioning information satisfies causality, and we run the inner decoder consisting of 4 blocks of masked row attention to sample a new row in the image (right).

Key ideas of the paper:

  • Decomposition of Multi-Dimensional Attention
    • Rather than applying attention to a flattened string of tensor elements, the model applies attention along a single axis of the tensor without flattening; the authors refer to this as “axial attention.”
    • This decomposes the multi-dimensional attention problem into several lower-dimensional attention operations.
  • Reduced Computational Complexity
    • Since the length of any single axis (the height or width of an image, say) is typically much smaller than the total number of elements, an axial attention operation enjoys a significant saving in computation and memory over standard self-attention (see the worked cost comparison below).
  • High Flexibility
    • Axial attention handles high-dimensional data well beyond images, such as video and other 3D data, where full attention mechanisms would be even more impractical.

By attending to rows and columns sequentially, axial attention can still aggregate information across the entire spatial extent. Although each attention operation is limited to one axis, applying them in sequence lets the model capture interactions across all dimensions:

  • First, attention along rows captures dependencies within each row.
  • Then, attention along columns captures dependencies within each column. This two-step process ensures that every position in the 2D grid can attend to every other position, albeit indirectly.
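
To make the savings concrete, here is a rough back-of-the-envelope comparison; the 256 x 256 image size is an illustrative assumption, not a figure from the paper:

H, W = 256, 256                         # illustrative image size
full_scores  = (H * W) ** 2             # full self-attention over H*W tokens: ~4.3e9 pairwise scores
axial_scores = H * W * (H + W)          # one row pass + one column pass: ~3.4e7 scores
print(full_scores / axial_scores)       # 128.0 -> roughly two orders of magnitude fewer scores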
Model | Full receptive field | Attention faster than O(N^2) | Needs no custom kernels | Semi-parallel context aggregation
--- | --- | --- | --- | ---
Transformer (Vaswani et al., 2017) | yes | no | yes | no
Image Transformer (Parmar et al., 2018) | no | yes | yes | no
Block Transformer (Weissenborn et al., 2019) | no | yes | yes | no
Strided Sparse Transformer (Child et al., 2019) | yes | yes | no | no
Axial Transformer | yes | yes | yes | yes
# https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L153
class AxialAttention(nn.Module):
    def __init__(self, dim, num_dimensions = 2, heads = 8, dim_heads = None, dim_index = -1, sum_axial_out = True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim
        self.total_dimensions = num_dimensions + 2
        self.dim_index = dim_index if dim_index > 0 else (dim_index + self.total_dimensions)

        attentions = []
        # calculate_permutations generates one permutation per axial dimension,
        # rearranging the tensor so that dimension becomes the attention axis.
        for permutation in calculate_permutations(num_dimensions, dim_index):
            # Each PermuteToFrom instance applies SelfAttention
            # to one specific dimension (e.g. rows or columns).
            attentions.append(PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads)))

        self.axial_attentions = nn.ModuleList(attentions)
        self.sum_axial_out = sum_axial_out  # sum the per-axis results, or apply them in sequence

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if self.sum_axial_out:
            return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))

        out = x
        for axial_attn in self.axial_attentions:
            out = axial_attn(out)
        return out


class PermuteToFrom(nn.Module):
    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        axial = x.permute(*self.permutation).contiguous()

        shape = axial.shape
        *_, t, d = shape

        # merge all but the axial dimension into the batch
        axial = axial.reshape(-1, t, d)

        # attention along the chosen axis
        axial = self.fn(axial, **kwargs)

        # restore original shape and permutation
        axial = axial.reshape(*shape)
        axial = axial.permute(*self.inv_permutation).contiguous()
        return axial


def calculate_permutations(num_dimensions, emb_dim):
    total_dimensions = num_dimensions + 2  # batch + embedding + axial dimensions
    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []

    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
        permutation = [*dims_rest, *last_two_dims]
        permutations.append(permutation)

    return permutations
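
The listing above references a sort_and_return_indices helper that is not shown. A minimal version consistent with how PermuteToFrom uses it (the second return value is the argsort of the permutation, which is also its inverse) could look like:

def sort_and_return_indices(arr):
    # indices that sort arr; for a permutation this equals its inverse
    indices = sorted(range(len(arr)), key=lambda i: arr[i])
    return [arr[i] for i in indices], indices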

Example Usage:

#=======================================
# Image
img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,                 # embedding dimension
    dim_index = 1,           # where the embedding dimension lives
    dim_heads = 32,          # dimension of each head (defaults to dim // heads if not supplied)
    heads = 1,               # number of heads for multi-head attention
    num_dimensions = 2,      # number of axial dimensions (2 for images, 3 for video, or more)
    sum_axial_out = True     # sum the contributions of attention on each axis, or run them sequentially (defaults to True)
)

attn(img)  # (1, 3, 256, 256)

#=======================================
# Channel-last image latents
img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,               # embedding dimension
    dim_index = -1,          # where the embedding dimension lives
    heads = 8,               # number of heads for multi-head attention
    num_dimensions = 2,      # number of axial dimensions (2 for images, 3 for video, or more)
)

attn(img)  # (1, 20, 20, 512)

#=======================================
# Video
video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,               # embedding dimension
    dim_index = 2,           # where the embedding dimension lives
    heads = 8,               # number of heads for multi-head attention
    num_dimensions = 3,      # number of axial dimensions (2 for images, 3 for video, or more)
)

attn(video)  # (1, 5, 128, 256, 256)
#=======================================

MSA Transformer

Multiple Sequence Alignment Transformer (2021)

The MSA Transformer is a transformer model designed to learn representations of evolutionary-related protein sequences from multiple sequence alignments (MSAs). It uses a masked language modeling objective to predict masked amino acids, capturing biologically relevant information useful for protein structure prediction tasks like contact prediction.
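
As a rough illustration of the masked-language-modeling objective (the vocabulary size, mask rate, mask token id, and random logits below are placeholders, not the paper's actual setup):

import torch
import torch.nn.functional as F

msa_tokens = torch.randint(0, 25, (16, 64))       # (sequences, positions); 25-way vocabulary assumed
mask = torch.rand(msa_tokens.shape) < 0.15        # mask ~15% of positions
corrupted = msa_tokens.masked_fill(mask, 24)      # 24 = hypothetical <mask> token id, fed to the model

logits = torch.randn(16, 64, 25)                  # stand-in for model(corrupted)
loss = F.cross_entropy(logits[mask], msa_tokens[mask])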

Figure (from the paper). Left: Sparsity structure of the attention. By constraining attention to operate over rows and columns, the computational cost is reduced from O(M^2 L^2) to O(L M^2) + O(M L^2), where M is the number of rows and L the number of columns in the MSA. Middle: Untied row attention uses a different attention map for each sequence in the MSA; tied row attention uses a single attention map for all sequences, thereby constraining the contact structure. Ablation studies consider both tied and untied attention; the final model uses tied attention. Right: A single MSA Transformer block. The depicted architecture is from the final model; some ablations alter the ordering of row and column attention.

Key ideas of the paper:

  • Alternates attention over rows and columns of the 2D state
    • This sparsity pattern in the attention over the MSA brings the cost down to O(L M^2) for the column attention and O(M L^2) for the row attention (see the sketch after this list).
    • Rather than applying a feedforward layer after each row or column attention, the block applies row and column attention followed by a single feedforward layer.
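
To connect the cost terms above to concrete shapes, and to illustrate the tied row attention from the figure caption, here is a minimal sketch (batch dimension omitted; the sizes are illustrative assumptions, not the paper's):

import torch

M, L, H, D = 16, 64, 8, 32                          # rows, columns, heads, head dim
q = torch.randn(M, L, H, D)
k = torch.randn(M, L, H, D)

# Row attention: one L x L map per row (untied) costs O(M * L^2) scores.
untied_row = torch.einsum('rihd,rjhd->rhij', q, k)  # (M, H, L, L)

# Tied row attention: summing over the row axis gives a single L x L map shared
# by all rows (the "rinhd,rjnhd->hnij" einsum in RowSelfAttention below does this,
# with an extra batch index n).
tied_row = torch.einsum('rihd,rjhd->hij', q, k)     # (H, L, L)

# Column attention: one M x M map per column costs O(L * M^2) scores.
col = torch.einsum('ichd,jchd->hcij', q, k)         # (H, L, M, M)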

# https://github.com/rmrao/msa-transformer/blob/main/modules.py#L191
# (imports added for completeness; NormalizedResidualBlock and FeedForwardNetwork
# are assumed to be defined elsewhere in the same modules.py)
import math
from typing import Optional

import torch
import torch.nn as nn


class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2 ** 14,
    ) -> None:
        super().__init__()

        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        row_self_attention = RowSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        column_self_attention = ColumnSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        feed_forward_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        # Row attention, column attention, and a single feedforward layer,
        # each wrapped in a residual block with layer norm.
        self.row_self_attention = self.build_residual(row_self_attention)
        self.column_self_attention = self.build_residual(column_self_attention)
        self.feed_forward_layer = self.build_residual(feed_forward_layer)

    def build_residual(self, layer: nn.Module):
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if need_head_weights:
            return x, column_attn, row_attn
        else:
            return x

class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # "hnij" drops the row index from the attention map: a single
        # (heads, batch, cols, cols) map is shared by all rows, i.e. tied row attention.
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        # extra 1/sqrt(num_rows) factor, since tied attention sums logits over rows
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        # process the MSA in chunks of rows to bound memory usage
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        attns = 0
        scaling = self.align_scaling(x)
        for start in range(0, num_rows, max_rows):
            attn_weights = self.compute_attention_weights(
                x[start : start + max_rows],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[
                    :, start : start + max_rows
                ]
                if self_attn_padding_mask is not None
                else None,
            )
            attns += attn_weights
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = []
        for start in range(0, num_rows, max_rows):
            output = self.compute_attention_update(
                x[start : start + max_rows], attn_probs
            )
            outputs.append(output)

        output = torch.cat(outputs, 0)
        return output, attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(
            num_rows, num_cols, batch_size, self.num_heads, self.head_dim
        )
        k = self.k_proj(x).view(
            num_rows, num_cols, batch_size, self.num_heads, self.head_dim
        )
        q *= scaling
        if self_attn_padding_mask is not None:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(
                4
            ).to(q)

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError
            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]

        if self_attn_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000,
            )

        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(
            num_rows, num_cols, batch_size, self.num_heads, self.head_dim
        )
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (
            num_rows * num_cols > self.max_tokens_per_msa
        ) and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
        else:
            scaling = self.align_scaling(x)
            attn_weights = self.compute_attention_weights(
                x, scaling, self_attn_mask, self_attn_padding_mask
            )
            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            output = self.compute_attention_update(x, attn_probs)
            return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        # process the MSA in chunks of columns to bound memory usage
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[
                    :, :, start : start + max_cols
                ]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with
            # padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            k = self.k_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            v = self.v_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            q *= self.scaling

            # one (num_rows x num_rows) attention map per column (untied)
            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(
                num_rows, num_cols, batch_size, embed_dim
            )
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (
            num_rows * num_cols
        ) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(
                x, self_attn_mask, self_attn_padding_mask
            )
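
As a hypothetical shape check for the two attention modules above (assuming the imports shown earlier; the sizes are illustrative): the MSA state is laid out as (rows, columns, batch, embedding), row attention returns a single tied map per head, and column attention returns one map per column.

x = torch.randn(16, 64, 1, 128)                 # 16 aligned sequences, 64 positions, batch of 1
row_attn = RowSelfAttention(embed_dim=128, num_heads=8)
col_attn = ColumnSelfAttention(embed_dim=128, num_heads=8)

out, row_maps = row_attn(x)                     # out: (16, 64, 1, 128), row_maps: (8, 1, 64, 64)
out, col_maps = col_attn(out)                   # out: (16, 64, 1, 128), col_maps: (8, 64, 1, 16, 16)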