# sglang_v0.5.2/flashinfer_0.3.1/tests/alibi_reference.py
"""
Attention with Linear Biases (ALiBi) reference implementation.
Code adapted from https://github.com/labmlai/annotated_deep_learning_paper_implementations
Licensed under the MIT License; you may obtain a copy of the license at
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
Source:
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/mha.py
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/alibi/__init__.py
"""
import math
from typing import Optional

import torch


def get_slopes(n_heads: int):
    r"""
    ## Get head-specific slope $m$ for each head

    * `n_heads` is the number of heads in the attention layer $n$

    The slope for the first head is

    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

    The slopes for the rest of the heads are in a geometric series with the same ratio as above.

    For instance, when the number of heads is $8$ the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
    """
    # Get the closest power of 2 to `n_heads`.
    # If `n_heads` is not a power of 2, then we first calculate slopes for the closest (smaller) power of 2,
    # and then add the remaining slopes.
    n = 2 ** math.floor(math.log2(n_heads))
    # $2^{-\frac{8}{n}}$
    m_0 = 2.0 ** (-8.0 / n)
    # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
    m = torch.pow(m_0, torch.arange(1, 1 + n))
    # If `n_heads` is not a power of 2, then we add the remaining slopes.
    # We calculate the remaining slopes as if for $2 n$ heads (skipping slopes added previously)
    # and pick slopes up to `n_heads`.
    if n < n_heads:
        # $2^{-\frac{8}{2n}}$
        m_hat_0 = 2.0 ** (-4.0 / n)
        # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
        # Note that we take steps of $2$ to avoid slopes added previously.
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
        # Concatenate the slopes with the remaining slopes.
        m = torch.cat([m, m_hat])

    return m
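

# A minimal sanity-check sketch (not part of the original file; the helper name is
# hypothetical): for `n_heads = 8` the slopes should be the geometric series
# 1/2, 1/4, ..., 1/256 described in the docstring above.
def _example_check_slopes():
    expected = torch.tensor([2.0 ** -(i + 1) for i in range(8)])
    assert torch.allclose(get_slopes(8), expected)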


@torch.no_grad()
def get_alibi_biases(n_heads: int, mask: torch.Tensor):
    """
    ## Calculate the attention biases matrix

    * `n_heads` is the number of heads in the attention layer
    * `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`

    This returns a tensor of shape `[1, seq_len_k, n_heads]` with ALiBi attention biases;
    it broadcasts over the query dimension to `[seq_len_q, seq_len_k, n_heads]`.
    """
    # Get slopes $m$ for each head
    m = get_slopes(n_heads).to(mask.device)
    # Calculate distances $[0, 1, \dots, N]$ along the key dimension.
    # Only the mask's shape and device are used here;
    # since the mask is causal, the distances are simply $[0, 1, \dots, N]$.
    distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[
        None, :
    ]
    # Multiply them pair-wise to get the ALiBi bias matrix
    return distance[:, :, None] * m[None, None, :]
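

# A minimal usage sketch (not part of the original file; the helper name is
# hypothetical): with a causal mask of key length 4 and 8 heads, the biases have
# shape [1, 4, 8] and grow linearly along the key dimension with per-head slopes
# from `get_slopes`.
def _example_bias_shape():
    mask = torch.tril(torch.ones(4, 4))
    biases = get_alibi_biases(8, mask)
    assert biases.shape == (1, 4, 8)
    # The bias for key position j under head h is j * m_h.
    assert torch.allclose(biases[0, :, 0], torch.arange(4) * get_slopes(8)[0])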


def alibi_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """
    query: [q_len, num_heads, head_dim]
    key: [kv_len, num_heads, head_dim]
    value: [kv_len, num_heads, head_dim]
    mask: [q_len, kv_len] (required; passing `None` is not supported below)
    """
    assert mask is not None, "alibi_attention requires an attention mask"
    q_len, num_heads, head_dim = query.shape
    # Attention scores $Q K^\top$ with shape `[q_len, kv_len, num_heads]`
    scores = torch.einsum("qhd,khd->qkh", query.float(), key.float())
    # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
    scores *= 1.0 / math.sqrt(head_dim)
    # Create ALiBi biases
    alibi_biases = get_alibi_biases(num_heads, mask)
    # Add ALiBi biases to the attention scores.
    # `alibi_biases` has shape `[1, kv_len, num_heads]` and broadcasts against
    # `scores`, which has shape `[q_len, kv_len, num_heads]`.
    scores += alibi_biases
    # Apply mask
    scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
    # $softmax$ attention along the key sequence dimension
    # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
    attn = torch.softmax(scores, dim=1)
    # Multiply by values
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
    return torch.einsum("qkh,khd->qhd", attn, value.float()).to(query)
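

# A minimal, self-contained smoke test (an assumption, not part of the original
# file): run the reference attention on random inputs with a causal mask, check
# the output shape, and check that the first query position attends only to the
# first key, so its output equals the first value vector.
if __name__ == "__main__":
    torch.manual_seed(0)
    q_len, kv_len, num_heads, head_dim = 4, 4, 8, 16
    query = torch.randn(q_len, num_heads, head_dim)
    key = torch.randn(kv_len, num_heads, head_dim)
    value = torch.randn(kv_len, num_heads, head_dim)
    causal_mask = torch.tril(torch.ones(q_len, kv_len))
    out = alibi_attention(query, key, value, causal_mask)
    assert out.shape == (q_len, num_heads, head_dim)
    # With a causal mask, the first query can only see the first key/value.
    assert torch.allclose(out[0], value[0], atol=1e-5)
    print("alibi_attention smoke test passed:", tuple(out.shape))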