"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from typing import Optional

import einops
import torch


def sink_softmax(logits, sink):
    # Broadcast the per-head sink logit to (b, h, m, 1).
    sink = einops.repeat(sink, "h -> b h m 1", b=logits.shape[0], m=logits.shape[2])
    # Append the sink column: (b, h, m, n + 1)
    logits = torch.cat([logits, sink], dim=-1)
    # Per row, starting from scores (s_1, ..., s_n) and the sink logit log(sink):
    #   (s_1, ..., s_n, log(sink))
    #   -> (exp(s_1), ..., exp(s_n), sink)
    #   -> (exp(s_1) / Z, ..., exp(s_n) / Z) with Z = exp(s_1) + ... + exp(s_n) + sink
    # The dropped last column, sink / Z, is the probability mass absorbed by the sink.
    score = torch.softmax(logits, dim=-1)[..., :-1].contiguous()
    return score
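

# A minimal sanity-check sketch (illustrative only, not part of the FlashInfer
# API): the shapes below are assumed for demonstration and the helper name
# `_sink_softmax_sanity_check` is hypothetical. If called, it checks that
# dropping the sink column leaves each row with strictly less than unit
# probability mass.
def _sink_softmax_sanity_check() -> None:
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 3, 5)  # (batch, heads, qo_len, kv_len), assumed shapes
    sink = torch.randn(4)  # one log-space sink logit per head
    score = sink_softmax(logits, sink)
    assert score.shape == logits.shape
    # The appended sink column absorbs part of the softmax mass, so rows sum to < 1.
    assert torch.all(score.sum(dim=-1) < 1.0)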


def sink_attention_unified(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sink: torch.Tensor,
    window_left: int,
    causal: bool,
    sm_scale: float,
    batch_size: Optional[int] = None,
    mode: str = "auto",
    qo_indptr: Optional[torch.Tensor] = None,
    kv_indptr: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
|
|
Unified sink attention implementation supporting prefill, incremental, chunk prefill, and variable-length scenarios.
|
|
|
|
Args:
|
|
q: Query tensor. Format depends on mode:
|
|
- Regular Prefill: [total_q_len, num_qo_heads, head_dim] where q_len == kv_len
|
|
- Incremental: [batch_size, num_qo_heads, head_dim] where q_len == 1
|
|
- Chunk Prefill: [total_q_len, num_qo_heads, head_dim] where q_len != kv_len and q_len > 1
|
|
- Variable Length: [total_q_len, num_qo_heads, head_dim] with different q_len per request
|
|
k: Key tensor. Format depends on mode:
|
|
- Regular Prefill: [total_kv_len, num_kv_heads, head_dim]
|
|
- Incremental: [batch_size, kv_len, num_kv_heads, head_dim]
|
|
- Chunk Prefill: [total_kv_len, num_kv_heads, head_dim]
|
|
- Variable Length: [total_kv_len, num_kv_heads, head_dim]
|
|
v: Value tensor, same format as k
|
|
sink: Sink values [num_qo_heads]
|
|
window_left: Sliding window size (-1 for no window)
|
|
causal: Whether to apply causal masking
|
|
sm_scale: Scaling factor for attention
|
|
batch_size: Required for prefill/chunk modes, auto-detected for incremental
|
|
mode: Processing mode:
|
|
- "auto": Auto-detect based on tensor shapes and dimensions
|
|
- "prefill": Regular prefill (q_len == kv_len)
|
|
- "incremental": Incremental generation (q_len == 1)
|
|
- "chunk": Chunk prefill (q_len != kv_len and q_len > 1)
|
|
- "varlen": Variable length sequences within batch
|
|
qo_indptr: Optional[torch.Tensor] - Query sequence length pointers for variable length mode.
|
|
Shape: [batch_size + 1]. qo_indptr[i+1] - qo_indptr[i] gives the query length for request i.
|
|
Only used when mode="varlen".
|
|
kv_indptr: Optional[torch.Tensor] - Key/Value sequence length pointers for variable length mode.
|
|
Shape: [batch_size + 1]. kv_indptr[i+1] - kv_indptr[i] gives the kv length for request i.
|
|
Only used when mode="varlen".
|
|
|
|
Returns:
|
|
Output tensor. Format depends on mode:
|
|
- Regular Prefill: [total_q_len, num_qo_heads, head_dim]
|
|
- Incremental: [batch_size, num_qo_heads, head_dim]
|
|
- Chunk Prefill: [total_q_len, num_qo_heads, head_dim]
|
|
- Variable Length: [total_q_len, num_qo_heads, head_dim]
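
    Example (a minimal illustrative call; the shapes, batch size, and scale
    below are assumed for demonstration only):
        # Regular prefill for a batch of 2 requests of length 4 each, with
        # flattened q/k/v of shape [batch_size * seq_len, num_heads, head_dim].
        out = sink_attention_unified(
            q, k, v, sink,
            window_left=-1,
            causal=True,
            sm_scale=head_dim**-0.5,
            batch_size=2,
        )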
    """

    # Auto-detect mode if not specified
    if mode == "auto":
        # Variable-length mode is indicated by the presence of indptr tensors
        if qo_indptr is not None or kv_indptr is not None:
            mode = "varlen"
        elif len(q.shape) == 3 and len(k.shape) == 4:
            # q: [batch_size, num_heads, head_dim], k: [batch_size, kv_len, num_heads, head_dim]
            # This is incremental mode
            mode = "incremental"
        elif len(q.shape) == 3 and len(k.shape) == 3:
            # Both q and k are flattened: [total_len, num_heads, head_dim]
            if batch_size is None:
                raise ValueError(
                    "batch_size is required for auto-detection in prefill/chunk modes"
                )

            qo_len = q.shape[0] // batch_size
            kv_len = k.shape[0] // batch_size

            if qo_len == kv_len:
                mode = "prefill"
            elif qo_len == 1:
                # Special case: a single token per request in the flattened layout.
                # Route it through the chunk path, which expects 3D (flattened) K/V;
                # the "incremental" branch below assumes a 4D KV cache.
                mode = "chunk"
            elif qo_len > 1 and qo_len != kv_len:
                mode = "chunk"
            else:
                raise ValueError(
                    f"Cannot auto-detect mode: qo_len={qo_len}, kv_len={kv_len}"
                )
        else:
            raise ValueError(
                f"Cannot auto-detect mode from tensor shapes: q={q.shape}, k={k.shape}"
            )

    # Process based on the detected/specified mode
    if mode == "incremental":
        # Incremental generation mode: q_len=1, kv_len from cache
        batch_size = q.shape[0]
        qo_len = 1
        kv_len = k.shape[1]
        num_qo_heads = q.shape[1]
        num_kv_heads = k.shape[2]

        # Handle GQA
        if num_qo_heads != num_kv_heads:
            k = torch.repeat_interleave(
                k, num_qo_heads // num_kv_heads, dim=2
            ).contiguous()
            v = torch.repeat_interleave(
                v, num_qo_heads // num_kv_heads, dim=2
            ).contiguous()
            num_kv_heads = num_qo_heads

        head_dim_qk = q.shape[2]
        head_dim_vo = v.shape[3]

        # Compute logits: [batch_size, num_heads, 1, kv_len]
        logits = (
            torch.einsum(
                "bhd,blhd->bhl",
                q.float(),
                k.float(),
            ).unsqueeze(2)  # Add seq_len=1 dimension
            * sm_scale
        )

    elif mode in ["prefill", "chunk"]:
        # Prefill or chunk prefill mode: q and k are flattened tensors
        if batch_size is None:
            raise ValueError(f"batch_size is required for {mode} mode")

        qo_len = q.shape[0] // batch_size
        kv_len = k.shape[0] // batch_size
        num_qo_heads = q.shape[1]
        num_kv_heads = k.shape[1]

        # Handle GQA
        if num_qo_heads != num_kv_heads:
            k = torch.repeat_interleave(
                k, num_qo_heads // num_kv_heads, dim=1
            ).contiguous()
            v = torch.repeat_interleave(
                v, num_qo_heads // num_kv_heads, dim=1
            ).contiguous()

        head_dim_qk = q.shape[2]
        head_dim_vo = v.shape[2]

        # Compute logits: [batch_size, num_heads, qo_len, kv_len]
        logits = (
            torch.einsum(
                "bmhd,bnhd->bhmn",
                q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
                k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
            )
            * sm_scale
        )

    elif mode == "varlen":
        # Variable-length sequences mode
        if qo_indptr is None or kv_indptr is None:
            raise ValueError("qo_indptr and kv_indptr are required for varlen mode")

        batch_size = qo_indptr.shape[0] - 1
        num_qo_heads = q.shape[1]
        num_kv_heads = k.shape[1]
        head_dim_qk = q.shape[2]
        head_dim_vo = v.shape[2]

        # Handle GQA
        if num_qo_heads != num_kv_heads:
            k = torch.repeat_interleave(
                k, num_qo_heads // num_kv_heads, dim=1
            ).contiguous()
            v = torch.repeat_interleave(
                v, num_qo_heads // num_kv_heads, dim=1
            ).contiguous()
            num_kv_heads = num_qo_heads

        # Process each request in the batch separately
        output_list = []

        for i in range(batch_size):
            # Extract tensors for the current request
            qo_start, qo_end = qo_indptr[i].item(), qo_indptr[i + 1].item()
            kv_start, kv_end = kv_indptr[i].item(), kv_indptr[i + 1].item()

            q_i = q[qo_start:qo_end]  # [qo_len_i, num_heads, head_dim]
            k_i = k[kv_start:kv_end]  # [kv_len_i, num_heads, head_dim]
            v_i = v[kv_start:kv_end]  # [kv_len_i, num_heads, head_dim]

            qo_len_i = qo_end - qo_start
            kv_len_i = kv_end - kv_start

            # Compute logits for the current request: [1, num_heads, qo_len_i, kv_len_i]
            logits_i = (
                torch.einsum(
                    "qhd,khd->hqk",
                    q_i.float(),
                    k_i.float(),
                ).unsqueeze(0)  # Add batch dimension
                * sm_scale
            )

            # Build the attention mask for the current request
            if causal:
                # Create a causal mask for this specific request
                row_idx = torch.arange(qo_len_i, dtype=torch.int32, device=q.device)[
                    :, None
                ]
                col_idx = torch.arange(kv_len_i, dtype=torch.int32, device=q.device)[
                    None, :
                ]

                # Default causal mask: position i can attend to positions 0 to i in the kv sequence,
                # assuming queries correspond to the last qo_len_i positions of the kv sequence
                query_positions = kv_len_i - qo_len_i + row_idx
                mask_i = query_positions >= col_idx

                if window_left >= 0:
                    mask_i &= query_positions - window_left <= col_idx
            else:
                # Non-causal mask
                mask_i = torch.ones(
                    qo_len_i, kv_len_i, device=q.device, dtype=torch.bool
                )
                if window_left >= 0:
                    row_idx = torch.arange(
                        qo_len_i, dtype=torch.int32, device=q.device
                    )[:, None]
                    col_idx = torch.arange(
                        kv_len_i, dtype=torch.int32, device=q.device
                    )[None, :]
                    query_positions = kv_len_i - qo_len_i + row_idx
                    mask_i = query_positions - window_left <= col_idx

            # Apply mask
            logits_i = logits_i.masked_fill(
                mask_i.unsqueeze(0).unsqueeze(0) == 0, float("-inf")
            )

            # Apply sink softmax
            p_i = sink_softmax(logits_i, sink)  # [1, num_heads, qo_len_i, kv_len_i]

            # Compute output for the current request
            o_i = (
                torch.einsum(
                    "bhmn,nhd->bmhd",
                    p_i,  # [1, num_heads, qo_len_i, kv_len_i]
                    v_i.float(),  # [kv_len_i, num_heads, head_dim]
                )
                .contiguous()
                .view(qo_len_i, num_qo_heads, head_dim_vo)
                .to(q)
            )

            output_list.append(o_i)

        # Concatenate outputs from all requests
        o_ref = torch.cat(output_list, dim=0)

        return o_ref

    else:
        raise ValueError(
            f"Unknown mode: {mode}. Supported modes: 'auto', 'prefill', 'incremental', 'chunk', 'varlen'"
        )

    # Build attention mask (unified for all modes)
    if causal:
        if mode == "incremental":
            # For incremental: the new token can attend to all previous tokens
            mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool)
            if window_left >= 0:
                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
                mask = (kv_len - 1 - window_left) <= col_idx
        elif mode == "prefill":
            # For regular prefill: standard causal mask
            mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
                1
            ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
            if window_left >= 0:
                row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                    :, None
                ]
                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                    None, :
                ]
                mask &= row_idx - window_left <= col_idx
        elif mode == "chunk":
            # For chunk prefill: each query position can attend to all previous KV positions.
            # The current chunk occupies the last positions: [kv_len - qo_len : kv_len]
            current_chunk_start = kv_len - qo_len
            row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                :, None
            ]  # Positions within the chunk
            col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                None, :
            ]  # All KV positions

            # Each position can attend to all history plus positions up to itself in the current chunk
            abs_row_positions = (
                current_chunk_start + row_idx
            )  # Absolute positions in the full sequence
            mask = abs_row_positions >= col_idx  # Standard causal mask

            if window_left >= 0:
                mask &= abs_row_positions - window_left <= col_idx
    else:
        # Non-causal mask
        if mode == "incremental":
            mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool)
            if window_left >= 0:
                col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
                mask = (kv_len - 1 - window_left) <= col_idx
        else:  # prefill or chunk
            mask = torch.ones(qo_len, kv_len, device=q.device, dtype=torch.bool)
            if window_left >= 0:
                if mode == "chunk":
                    # For chunk mode, apply the window relative to absolute positions
                    current_chunk_start = kv_len - qo_len
                    row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                        :, None
                    ]
                    col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                        None, :
                    ]
                    abs_row_positions = current_chunk_start + row_idx
                    mask = abs_row_positions - window_left <= col_idx
                else:  # prefill
                    row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                        :, None
                    ]
                    col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[
                        None, :
                    ]
                    mask = row_idx - window_left <= col_idx

    # Apply mask
    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

    # Apply sink softmax
    p = sink_softmax(logits, sink)

    # Compute output
    if mode == "incremental":
        # Incremental mode output
        o_ref = (
            torch.einsum(
                "bhml,blhd->bhd",
                p,  # [batch_size, num_heads, 1, kv_len]
                v.float(),  # [batch_size, kv_len, num_heads, head_dim]
            )
            .contiguous()
            .to(q)
        )
    else:  # prefill or chunk mode
        # Prefill/chunk mode output
        o_ref = (
            torch.einsum(
                "bhmn,bnhd->bmhd",
                p,  # [batch_size, num_heads, qo_len, kv_len]
                v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
            )
            .contiguous()
            .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
            .to(q)
        )

    return o_ref
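

# A minimal smoke-test sketch, assuming CPU tensors and small illustrative
# shapes; it is not part of the FlashInfer test suite and the sizes below are
# arbitrary. It calls sink_attention_unified in two of the documented layouts
# and prints the resulting output shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, qo_len, kv_len = 2, 4, 4
    num_qo_heads, num_kv_heads, head_dim = 8, 2, 16
    sm_scale = head_dim**-0.5
    sink = torch.randn(num_qo_heads)  # one log-space sink logit per head

    # Regular prefill: flattened q/k/v with qo_len == kv_len.
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim)
    k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim)
    v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim)
    out = sink_attention_unified(
        q, k, v, sink,
        window_left=-1,
        causal=True,
        sm_scale=sm_scale,
        batch_size=batch_size,
    )
    print("prefill output:", tuple(out.shape))  # (batch_size * qo_len, num_qo_heads, head_dim)

    # Incremental decode: one query token per request against a per-request KV cache.
    decode_kv_len = 6
    q = torch.randn(batch_size, num_qo_heads, head_dim)
    k = torch.randn(batch_size, decode_kv_len, num_kv_heads, head_dim)
    v = torch.randn(batch_size, decode_kv_len, num_kv_heads, head_dim)
    out = sink_attention_unified(
        q, k, v, sink,
        window_left=-1,
        causal=True,
        sm_scale=sm_scale,
    )
    print("incremental output:", tuple(out.shape))  # (batch_size, num_qo_heads, head_dim)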