367 lines
12 KiB
Python
367 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
|
|
from sglang.srt.distributed import parallel_state
|
|
from sglang.srt.distributed import utils as dist_utils
|
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
|
context_attention_fwd,
|
|
)
|
|
from sglang.srt.layers.linear import (
|
|
ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.quantization import QuantizationConfig
|
|
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
|
|
from sglang.srt.utils import add_prefix
|
|
|
|
|
|
class VisionAttention(nn.Module):
    r"""
    Multi-headed attention without any cache, mostly used for ViT.

    Args:
        embed_dim: hidden size of the input ``x`` and of the output.
        num_heads: total number of attention heads; split across the
            tensor-model-parallel world size.
        projection_size: total width of each of the q/k/v projections; sizes
            the fused ``ColumnParallelLinear`` used when ``use_qkv_parallel``
            is False.
        use_qkv_parallel (bool): If True, use ``QKVParallelLinear`` and keep
            the ``[b, s, ...]`` layout internally. Otherwise use a fused
            ``ColumnParallelLinear`` and an internal ``[s, b, ...]`` layout.
        quant_config: optional quantization config forwarded to the linear
            layers.
        dropout: attention dropout probability forwarded to the SDPA backend.
        use_context_forward (bool, default to True):
            if ``True``, the Triton varlen (flash_attn style) kernel is used.
            Otherwise, a full-sequence SDPA attention is applied.
        softmax_in_single_precision (bool, default to False):
            if ``True`` (SDPA backend only), the softmax is computed in
            float32 by a manual attention implementation. Otherwise,
            ``F.scaled_dot_product_attention`` is used.
        flatten_batch: forwarded to the SDPA backend; selects the mask layout.
        prefix: parameter-name prefix for the projection layers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        dropout: float = 0.0,
        use_context_forward: bool = True,
        softmax_in_single_precision: bool = False,
        flatten_batch: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.use_context_forward = use_context_forward
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.dropout = dropout
        self.head_size = embed_dim // num_heads
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )

        if self.use_context_forward:
            self.qkv_backend = VisionTritonAttention()
        else:
            self.qkv_backend = VisionSdpaAttention(
                head_size=self.head_size,
                dropout=dropout,
                flatten_batch=flatten_batch,
                softmax_in_single_precision=softmax_in_single_precision,
            )

        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_size,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        else:
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        # NOTE(review): input_size is embed_dim, not projection_size; this only
        # matches the attention output width when the two are equal — confirm.
        self.proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            x: [b, s, embed_dim]
            cu_seqlens: cumulative sequence lengths, [num_seqs + 1]
            position_embeddings: optional ``(cos, sin)`` pair applied to q and
                k with ``apply_rotary_pos_emb``.
            attention_mask: optional mask, used only by the SDPA backend.
        Returns:
            [b, s, embed_dim]
        """
        bsz, s, _ = x.shape
        head = self.num_attention_heads_per_partition

        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, 3 * head * head_size]
            qkv, _ = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)
            # [b, s, head * head_size] --> [b * s, head, head_size]
            q, k, v = [t.reshape(bsz * s, head, -1).contiguous() for t in (q, k, v)]
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                head,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)
            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
            # [s, b, head, head_size] --> [b, s, head, head_size]
            q, k, v = [
                rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v)
            ]

        if position_embeddings is not None:
            cos, sin = position_embeddings
            original_shape = q.shape
            # [total_tokens, head, head_size]
            q = q.view(-1, head, self.head_size)
            k = k.view(-1, head, self.head_size)

            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            q = q.view(original_shape)
            k = k.view(original_shape)

        if not self.use_qkv_parallel:
            # [b, s, head, head_size] --> [b * s, head, head_size]
            q, k, v = [rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v)]

        if self.use_context_forward:
            # The Triton varlen kernel takes segment boundaries from
            # ``cu_seqlens`` and has no mask argument.
            output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens)
        else:
            output = self.qkv_backend.forward(
                q, k, v, bsz, cu_seqlens, attention_mask
            )

        if self.use_qkv_parallel:
            # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
            # [b, s, h * head_size] --> [b, s, embed_dim]
            output, _ = self.proj(output)
        else:
            # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()
            # [s, b, h * head_size] --> [s, b, embed_dim]
            output, _ = self.proj(context_layer)
            # [s, b, embed_dim] --> [b, s, embed_dim]. This must be a real axis
            # swap: a plain ``view(bsz, s, -1)`` reinterprets the [s, b] memory
            # layout and mixes tokens across batch entries whenever bsz > 1.
            output = rearrange(output, "s b ... -> b s ...").contiguous()

        return output
|
|
|
|
|
|
class VisionSdpaAttention(nn.Module):
    r"""
    Scaled dot-product attention over packed ``[b * s, h, head_size]`` inputs.

    A boolean mask (caller-supplied, or built from ``cu_seqlens``) restricts
    which key positions each query may attend to. Dropout is applied only
    while the module is in training mode, on both computation paths.
    """

    def __init__(
        self,
        head_size: int,
        dropout: float = 0.0,
        flatten_batch: bool = False,
        softmax_in_single_precision: bool = False,
    ):
        super().__init__()
        self.head_size = head_size
        self.flatten_batch = flatten_batch
        self.softmax_in_single_precision = softmax_in_single_precision
        self.dropout = dropout

    @staticmethod
    @lru_cache(maxsize=128)
    def _generate_mask_cache(
        s: int, flatten_batch: bool, cu_seqlens: tuple
    ) -> torch.BoolTensor:
        """
        Generate a boolean attention mask (True = may attend), memoized.

        Args:
            s: sequence length
            flatten_batch: if True, build one block-diagonal ``(1, s, s)``
                mask with one block per ``cu_seqlens`` segment; otherwise a
                ``(b, 1, s, s)`` padding mask, one entry per segment.
            cu_seqlens: tuple of cumulative sequence lengths (hashable so it
                can key the cache)
        Returns:
            attention mask tensor (on CPU; shared between calls with equal
            arguments, so callers must not modify it in place)
        """
        bounds = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))

        if flatten_batch:
            mask = torch.zeros([1, s, s], dtype=torch.bool)
            for start, end in bounds:
                mask[..., start:end, start:end] = True
            return mask

        positions = torch.arange(s)
        # [b, 1, 1, 1]
        seq_lens = torch.tensor([end - start for start, end in bounds]).view(
            -1, 1, 1, 1
        )
        # Query index on dim -2, key index on dim -1.
        query_valid = positions.view(1, 1, s, 1) < seq_lens
        key_valid = positions.view(1, 1, 1, s) < seq_lens
        return query_valid & key_valid

    def generate_patch_attention_mask(
        self,
        s: int,
        cu_seqlens: Optional[torch.Tensor],
        flatten_batch: bool = False,
    ) -> Optional[torch.Tensor]:
        r"""
        Create a non-causal mask of shape ``(b, 1, s, s)``, or ``(1, s, s)``
        when ``flatten_batch`` is True.

        Args:
            s: sequence length
            cu_seqlens: cumulative sequence lengths tensor. If None, returns None.
            flatten_batch: whether to build a block-diagonal mask over a
                single flattened batch
        Returns:
            attention mask tensor or None
        """
        if cu_seqlens is None:
            return None

        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())

        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            q, k, v: [b * s, h, head_size]
            bsz: batch size b
            cu_seqlens: cumulative sequence lengths, [num_seqs + 1]; used to
                build the mask when ``attention_mask`` is None
            attention_mask: optional boolean mask (True = may attend)
        Returns:
            [b * s, h, head_size]
        Raises:
            RuntimeError: if no mask is available on the single-precision path.
        """
        s = q.shape[0] // bsz

        # [b, 1, s, s] (or [1, s, s] when flatten_batch)
        if attention_mask is None:
            attention_mask = self.generate_patch_attention_mask(
                s, cu_seqlens, flatten_batch=self.flatten_batch
            )

        if attention_mask is None:
            if self.softmax_in_single_precision:
                raise RuntimeError("Empty attention mask")
        else:
            attention_mask = attention_mask.to(device=q.device)

        # [b * s, h, d] --> [b, h, s, d]
        q, k, v = [
            t.reshape(bsz, -1, t.shape[1], t.shape[2]).transpose(1, 2)
            for t in (q, k, v)
        ]

        # F.scaled_dot_product_attention applies dropout whenever
        # dropout_p > 0, regardless of module mode, so gate it on training.
        dropout_p = self.dropout if self.training else 0.0

        if self.softmax_in_single_precision:
            scale = self.head_size**-0.5
            attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
            del k
            # Masked-out (False) positions get the most negative finite value.
            additive_mask = (~attention_mask) * torch.finfo(q.dtype).min
            attn_weights = attn_weights + additive_mask
            del attention_mask, additive_mask
            # full-precision softmax, cast back to the input dtype
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
                q.dtype
            )
            attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training)
            output = torch.matmul(attn_weights, v)
            del attn_weights, v
        else:
            # SDPA
            # [b, h, s, head_size]
            output = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=False,
            )

        # [b, h, s, head_size] --> [b * s, h, head_size]
        b, h, seq, d = output.shape
        output = output.transpose(1, 2).reshape(b * seq, h, d)

        return output
|
|
|
|
|
|
class VisionTritonAttention(nn.Module):
    """
    Triton-implemented varlen attention without a causal mask.

    Segment boundaries are taken from ``cu_seqlens``; no explicit mask is used.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        _bsz: int,
        cu_seqlens: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            q, k, v: [b * s, head, head_size]
            _bsz: unused; kept so the call signature matches
                ``VisionSdpaAttention.forward``.
            cu_seqlens: cumulative sequence lengths, [num_seqs + 1]. Required.
            attention_mask: ignored. Accepted so callers may pass the same
                positional arguments they pass to ``VisionSdpaAttention``.
        Returns:
            [b * s, head, head_size]
        Raises:
            ValueError: if ``cu_seqlens`` is None.
        """
        if cu_seqlens is None:
            raise ValueError("VisionTritonAttention requires cu_seqlens")

        # [b * s, head, head_size]
        output = torch.empty_like(q)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()
        # Place the length metadata on q's device rather than the current
        # default CUDA device.
        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens.to(q.device),
            seq_lens.to(q.device),
            max_seqlen,
            is_causal=False,
        )

        return output
|