# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for prefill.
It supports page size = 1 and prefill with KV cache (i.e. extend).
"""
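
# Index layout used below, shown with a minimal illustrative sketch: a batch of
# two requests, where request 0 has a 6-token cached prefix and 3 new (extend)
# tokens, and request 1 has a 2-token prefix and 4 new tokens.
#
#   qo_indptr  = [0, 3, 7]   # row offsets of each request in q_extend / o_extend
#   kv_indptr  = [0, 6, 8]   # offsets of each request's prefix in kv_indices
#   kv_indices = [...]       # slot of every prefix token inside k_buffer / v_buffer
#
# Keys and values of the new tokens live in k_extend / v_extend, while the cached
# prefix is gathered from k_buffer / v_buffer through kv_indices.
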
import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)
from sglang.srt.utils import is_hip

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()

_is_hip = is_hip()


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_kernel(
    Q_Extend,
    K_Extend,
    V_Extend,
    O_Extend,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    mask_ptr,
    mask_indptr,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    logit_cap: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_CUSTOM_MASK: tl.constexpr,
    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
    STORE_TRANSPOSE: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
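    # Grouped-/multi-query attention: kv_group_num query heads share one KV head.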
    cur_kv_head = cur_head // kv_group_num

    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend

    if USE_CUSTOM_MASK:
        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    offs_q = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(
        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
    )

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
            * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

    # stage 1: compute scores with prefix
    offs_n = tl.arange(0, BLOCK_N)
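
    # Flash-attention style online-softmax state, kept per query row: `acc` is the
    # un-normalized weighted sum of V, `deno` the softmax denominator, and `e_max`
    # the running maximum of the logits.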
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
        offs_kv_loc = tl.load(
            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
        )

        # load k transposed: head dim along rows, tokens along columns
        offs_buf_k = (
            offs_kv_loc[None, :] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q.to(k.dtype), k)
        if BLOCK_DPE > 0:
            offs_kpe = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Buffer + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe.to(kpe.dtype), kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
            qk = tl.where(custom_mask, qk, float("-inf"))
        else:
            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_buf_v = (
            offs_kv_loc[:, None] * stride_buf_vbs
            + cur_kv_head * stride_buf_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    # stage 2: compute the triangle part
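    # The "triangle part" is the causal attention of the extend tokens over
    # themselves: each new token attends to itself and to earlier extend tokens.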

    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end

        # load k transposed: head dim along rows, tokens along columns
        offs_k = (
            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
            + cur_kv_head * stride_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q, k, out_dtype=tl.float32)
        if BLOCK_DPE > 0:
            offs_kpe = (
                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                + cur_kv_head * stride_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Extend + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe, kpe)

        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
                + cur_seq_len_prefix
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
            qk = tl.where(custom_mask, qk, float("-inf"))
        else:
            mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causal &= mask_m[:, None] & mask_n[None, :]
            qk = tl.where(mask_causal, qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_v = (
            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
            + cur_kv_head * stride_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    offs_o = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    if STORE_TRANSPOSE:
        tl.store(
            O_Extend + offs_o.T,
            (acc / deno[:, None]).T,
            mask=(mask_m[:, None] & mask_dv[None, :]).T,
        )
    else:
        tl.store(
            O_Extend + offs_o,
            acc / deno[:, None],
            mask=mask_m[:, None] & mask_dv[None, :],
        )


def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    mask_indptr,
    max_len_extend,
    sm_scale=None,
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
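    # Shapes, inferred from the stride usage below: q/k/v/o_extend are
    # [num_extend_tokens, num_heads, head_dim]; k_buffer and v_buffer are
    # [total_pool_tokens, num_kv_heads, head_dim].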
    Lq, Lk, Lv = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
    )
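
    # For a few known head dims the head is split into a dense part (BLOCK_DMODEL)
    # and a trailing part handled separately (BLOCK_DPE), e.g. the 64 rotary dims
    # of 576-dim MLA-style heads; otherwise the whole head is one padded block.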
    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    if _is_hip:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4

    else:
        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            # sm86/sm89 have a much smaller shared memory size (100K) than sm80 (160K)
            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (64, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 32)
            else:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (128, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 64)
        else:
            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

        num_warps = 4 if Lk <= 64 else 8

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    USE_CUSTOM_MASK = custom_mask is not None
    # Skip custom mask for prefix part
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
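
    # One kernel instance per (request, query head, BLOCK_M-sized tile of extend tokens).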
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

    extra_kargs = {}
    if _is_hip:
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}

    _fwd_kernel[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
        num_stages=num_stages,
        **extra_kargs,
    )
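

# A simple but redundant reference path: scatter the extend queries into a
# full-length buffer, run dense causal prefill attention over the whole
# (prefix + extend) sequence with context_attention_fwd, then copy back only the
# rows of the extend tokens. Prefix rows of q_buffer are left uninitialized
# because their outputs are discarded.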
def redundant_attention(
    q_extend,
    o_extend,
    k_buffer,
    v_buffer,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_seq_len_prefix,
    max_len_in_batch,
):
    total_token_num = k_buffer.shape[0]
    B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
    q_buffer = torch.empty(
        (total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
    )

    pt = 0
    for i in range(B):
        cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
        pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
        q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]
        pt += cur_seq_len_extend

    o_buffer = torch.empty_like(q_buffer)
    context_attention_fwd(
        q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
    )

    pt = 0
    for i in range(B):
        cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
        pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
        o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
        pt += cur_seq_len_extend
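

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only (names and sizes below are made
    # up): one request with a 6-token cached prefix and 3 extend tokens, 2 query
    # heads sharing 1 KV head, head dim 64. Assumes a CUDA/ROCm device, that the
    # prefix occupies slots 0..5 of the KV pool, and a Triton version that accepts
    # None for the unused custom-mask arguments.
    device = "cuda"
    dtype = torch.float16
    prefix_len, extend_len, h_q, h_kv, d = 6, 3, 2, 1, 64

    q_extend = torch.randn(extend_len, h_q, d, dtype=dtype, device=device)
    k_extend = torch.randn(extend_len, h_kv, d, dtype=dtype, device=device)
    v_extend = torch.randn(extend_len, h_kv, d, dtype=dtype, device=device)
    o_extend = torch.empty_like(q_extend)

    # KV pool holding the cached prefix tokens.
    k_buffer = torch.randn(prefix_len, h_kv, d, dtype=dtype, device=device)
    v_buffer = torch.randn(prefix_len, h_kv, d, dtype=dtype, device=device)

    qo_indptr = torch.tensor([0, extend_len], dtype=torch.int64, device=device)
    kv_indptr = torch.tensor([0, prefix_len], dtype=torch.int64, device=device)
    kv_indices = torch.arange(prefix_len, dtype=torch.int64, device=device)

    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask=None,
        mask_indptr=None,
        max_len_extend=extend_len,
    )
    print(o_extend.shape)  # torch.Size([3, 2, 64])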