"""
|
|
Attention with Linear Biases (ALiBi) reference implementation.
|
|
|
|
Code adapted from https://github.com/labmlai/annotated_deep_learning_paper_implementations
|
|
|
|
Licensed under MIT, you may obtain a copy of the License at
|
|
|
|
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
|
|
|
|
Source:
|
|
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/mha.py
|
|
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/alibi/__init__.py
|
|
"""
|
|
|
|
import math
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
def get_slopes(n_heads: int):
    r"""
    ## Get head-specific slope $m$ for each head

    * `n_heads` is the number of heads in the attention layer $n$

    The slope for the first head is

    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

    The slopes for the rest of the heads are in a geometric series with the same ratio.

    For instance, when the number of heads is $8$, the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
    """

    # Get the closest (smaller or equal) power of 2 to `n_heads`.
    # If `n_heads` is not a power of 2, we first calculate the slopes for that
    # smaller power of 2 and then add the remaining slopes.
    n = 2 ** math.floor(math.log2(n_heads))
    # $2^{-\frac{8}{n}}$
    m_0 = 2.0 ** (-8.0 / n)
    # $2^{-1 \cdot \frac{8}{n}}, 2^{-2 \cdot \frac{8}{n}}, 2^{-3 \cdot \frac{8}{n}}, \dots$
    m = torch.pow(m_0, torch.arange(1, 1 + n))

    # If `n_heads` is not a power of 2, we add the remaining slopes.
    # We calculate them as if there were $2 n$ heads (skipping the slopes added
    # previously) and pick slopes up to `n_heads`.
    if n < n_heads:
        # $2^{-\frac{8}{2n}}$
        m_hat_0 = 2.0 ** (-4.0 / n)
        # $2^{-1 \cdot \frac{8}{2n}}, 2^{-3 \cdot \frac{8}{2n}}, 2^{-5 \cdot \frac{8}{2n}}, \dots$
        # Note that we step by $2$ to avoid the slopes added previously.
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
        # Concatenate the slopes with the remaining slopes.
        m = torch.cat([m, m_hat])

    return m
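

# Illustration (not part of the original source; the helper name below is ours):
# for 8 heads the slopes should be exactly 1/2, 1/4, ..., 1/256, and for a
# non-power-of-2 head count such as 12 the first 8 slopes are the same while the
# remaining 4 are drawn from the $2n = 16$ geometric series.
def _slopes_example() -> None:
    assert torch.allclose(
        get_slopes(8), torch.pow(0.5, torch.arange(1, 9, dtype=torch.float32))
    )
    assert get_slopes(12).shape == (12,)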


@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    """
    ## Calculate the attention biases matrix

    * `n_heads` is the number of heads in the attention layer
    * `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`

    This returns a matrix of shape `[1, seq_len_k, n_heads]` with ALiBi attention
    biases; it broadcasts over the query dimension when added to the attention scores.
    """

    # Get slopes $m$ for each head
    m = get_slopes(n_heads).to(mask.device)

    # Calculate key positions $[0, 1, \dots, N - 1]$.
    # Only the shape and device of the mask are used here.
    #
    # Since the mask is causal, using absolute key positions instead of relative
    # distances only adds a per-query-row constant to the scores, which cancels
    # in the softmax.
    distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[
        None, :
    ]

    # Multiply positions and slopes (with broadcasting) to get the ALiBi bias matrix
    return distance[:, :, None] * m[None, None, :]
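

# Illustration (not part of the original source; the helper name below is ours):
# with a causal mask, the bias for head $h$ at key position $j$ is $m_h \cdot j$,
# and the returned tensor broadcasts over the query dimension.
def _alibi_biases_example() -> None:
    seq_len, n_heads = 4, 8
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = attend, 0 = masked
    biases = get_alibi_biases(n_heads, causal_mask)
    assert biases.shape == (1, seq_len, n_heads)
    # The first of 8 heads has slope 1/2, so its biases over key positions are 0, 0.5, 1.0, 1.5.
    assert torch.allclose(biases[0, :, 0], 0.5 * torch.arange(seq_len, dtype=torch.float32))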


def alibi_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """
    query: [q_len, num_heads, head_dim]
    key: [kv_len, num_heads, head_dim]
    value: [kv_len, num_heads, head_dim]
    mask: [q_len, kv_len]

    Returns a tensor of shape [q_len, num_heads, head_dim].
    """
    q_len, num_heads, head_dim = query.shape

    # `mask` defaults to `None` but is used unconditionally below; assume full
    # (all-ones) attention when no mask is given.
    if mask is None:
        mask = torch.ones(q_len, key.shape[0], dtype=torch.bool, device=query.device)

    # Attention scores $Q K^\top$ for every (query position, key position, head)
    scores = torch.einsum("qhd,khd->qkh", query.float(), key.float())
    # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
    scores *= 1.0 / math.sqrt(head_dim)

    # Compute the ALiBi biases for this mask
    alibi_biases = get_alibi_biases(num_heads, mask)

    # Add ALiBi biases to the attention scores.
    # The biases have shape `[1, kv_len, n_heads]` and broadcast against
    # `scores`, which has shape `[q_len, kv_len, n_heads]`.
    scores += alibi_biases

    # Apply the mask
    scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))

    # $softmax$ attention along the key sequence dimension (`dim=1`)
    # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
    attn = torch.softmax(scores, dim=1)

    # Multiply by values
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
    return torch.einsum("qkh,khd->qhd", attn, value.float()).to(query)
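

# Minimal usage sketch (not part of the original source): run causal
# self-attention with ALiBi biases on random inputs and sanity-check the
# example helpers above.
if __name__ == "__main__":
    _slopes_example()
    _alibi_biases_example()

    q_len, kv_len, n_heads, head_dim = 5, 5, 8, 16
    query = torch.randn(q_len, n_heads, head_dim)
    key = torch.randn(kv_len, n_heads, head_dim)
    value = torch.randn(kv_len, n_heads, head_dim)
    # Causal mask: each query position attends to itself and earlier key positions.
    causal_mask = torch.tril(torch.ones(q_len, kv_len))  # 1 = attend, 0 = masked
    out = alibi_attention(query, key, value, causal_mask)
    print(out.shape)  # torch.Size([5, 8, 16])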