from __future__ import annotations

from functools import lru_cache
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
from sglang.srt.utils import add_prefix


class VisionAttention(nn.Module):
    r"""
    Multi-headed attention without any cache, mostly used for ViT.

    Args:
        use_qkv_parallel (bool, optional): If ``True``, use QKV-parallel attention.
        use_context_forward (bool, defaults to ``True``):
            If ``True``, a flash-attention-style (Triton) attention is applied;
            otherwise, a full-sequence SDPA attention is applied.
        softmax_in_single_precision (bool, defaults to ``False``):
            If ``True``, the softmax is performed in single precision;
            otherwise, it is performed in half precision.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        dropout: float = 0.0,
        use_context_forward: bool = True,
        softmax_in_single_precision: bool = False,
        flatten_batch: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.use_context_forward = use_context_forward
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.dropout = dropout
        self.head_size = embed_dim // num_heads
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )

        if self.use_context_forward:
            self.qkv_backend = VisionTritonAttention()
        else:
            self.qkv_backend = VisionSdpaAttention(
                head_size=self.head_size,
                dropout=dropout,
                flatten_batch=flatten_batch,
                softmax_in_single_precision=softmax_in_single_precision,
            )

        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_size,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        else:
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=add_prefix("qkv_proj", prefix),
            )
        self.proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            x: [b, s, embed_dim]
            cu_seqlens: cumulative sequence lengths of the packed patch
                sequences, starting at 0
        Returns:
            [b, s, embed_dim]
        """
        bsz, s, _ = x.shape
        head = self.num_attention_heads_per_partition
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
            # split the fused projection into equally sized q, k, v
            q, k, v = qkv.chunk(3, dim=-1)

            # [b, s, head * head_size] --> [b * s, head, head_size]
            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                head,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

            # [s, b, head, 3 * head_size] --> 3 x [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

            # [s, b, head, head_size] --> [b, s, head, head_size]
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
            ]

        if position_embeddings is not None:
            cos, sin = position_embeddings
            original_shape = q.shape
            # [total_tokens, head, head_size]
            q = q.view(-1, head, self.head_size)
            k = k.view(-1, head, self.head_size)

            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            q = q.view(original_shape)
            k = k.view(original_shape)

        if self.use_qkv_parallel:
            pass
        else:
            # [b, s, head, head_size] --> [b * s, head, head_size]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)

        if self.use_qkv_parallel:
            # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

            # [b, s, h * head_size] --> [b, s, embed_dim]
            output, _ = self.proj(output)
        else:
            # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()

            # [s, b, h * head_size] --> [s, b, embed_dim]
            output, _ = self.proj(context_layer)

            # [s, b, embed_dim] --> [b, s, embed_dim]
            output = output.view(bsz, s, -1)

        return output
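

# Illustrative usage sketch, not called anywhere in this module: it shows how
# VisionAttention is typically dropped into a ViT block. It assumes a tensor-
# parallel process group has already been initialized via
# sglang.srt.distributed.parallel_state, and the hyper-parameters and shapes
# below are placeholders rather than values from any particular model.
def _example_vision_attention_usage() -> torch.Tensor:
    attn = VisionAttention(
        embed_dim=1280,
        num_heads=16,
        projection_size=1280,
        use_qkv_parallel=True,
        use_context_forward=False,  # use the SDPA backend instead of Triton
    ).to(device="cuda", dtype=torch.float16)
    # Two images of 1024 patches each, stacked into a [b, s, embed_dim] batch.
    x = torch.randn(2, 1024, 1280, dtype=torch.float16, device="cuda")
    # Cumulative patch counts (prefix sums) for the two images, starting at 0.
    cu_seqlens = torch.tensor([0, 1024, 2048], dtype=torch.int32)
    # Returns [b, s, embed_dim], ready for the ViT MLP that follows.
    return attn(x, cu_seqlens=cu_seqlens)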


class VisionSdpaAttention(nn.Module):
    r"""
    Scaled dot-product attention backend (torch SDPA), used when a
    flash-attention-style kernel is not requested.
    """

    def __init__(
        self,
        head_size: int,
        dropout: float = 0.0,
        flatten_batch: bool = False,
        softmax_in_single_precision: bool = False,
    ):
        super().__init__()
        self.head_size = head_size
        self.flatten_batch = flatten_batch
        self.softmax_in_single_precision = softmax_in_single_precision
        self.dropout = dropout

    @staticmethod
    @lru_cache(maxsize=128)
    def _generate_mask_cache(
        s: int, flatten_batch: bool, cu_seqlens: tuple
    ) -> torch.BoolTensor:
        """
        Generate a boolean attention mask, cached on (s, flatten_batch, cu_seqlens).

        Args:
            s: sequence length
            flatten_batch: whether to flatten batch dimension
            cu_seqlens: tuple of cumulative sequence lengths

        Returns:
            attention mask tensor
        """
        if flatten_batch:
            # block-diagonal mask: each packed sequence attends only to itself
            mask = torch.zeros([1, s, s], dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                start = cu_seqlens[i - 1]
                end = cu_seqlens[i]
                mask[..., start:end, start:end] = True
        else:
            # per-sequence padding mask: positions beyond each sequence length
            # are masked out
            # [1, 1, 1, s]
            row_indices = torch.arange(s).view(1, 1, 1, s)
            # [1, 1, s, 1]
            col_indices = torch.arange(s).view(1, 1, s, 1)
            # [b, 1, 1, 1]
            seq_lens = torch.tensor(
                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
            ).view(-1, 1, 1, 1)

            mask = (row_indices < seq_lens) & (col_indices < seq_lens)

        return mask

    def generate_patch_attention_mask(
        self,
        s: int,
        cu_seqlens: Optional[torch.Tensor],
        flatten_batch: bool = False,
    ) -> Optional[torch.Tensor]:
        r"""
        Creates a non-causal mask of shape `(b, 1, s, s)`, or `(1, s, s)` when
        ``flatten_batch`` is set.

        Args:
            s: sequence length
            cu_seqlens: cumulative sequence lengths tensor.
                If ``None``, no mask is generated and ``None`` is returned.
            flatten_batch: whether to flatten batch dimension

        Returns:
            attention mask tensor or None
        """
        if cu_seqlens is None:
            return None

        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())

        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            cu_seqlens: cumulative sequence lengths, starting at 0
        Returns:
            [b * s, h, head_size]
        """
        s = q.shape[0] // bsz

        # [b, 1, s, s]
        if attention_mask is None:
            attention_mask = self.generate_patch_attention_mask(
                s, cu_seqlens, flatten_batch=self.flatten_batch
            )

        if attention_mask is None:
            if self.softmax_in_single_precision:
                raise RuntimeError("Empty attention mask")
        else:
            attention_mask = attention_mask.to(device=q.device)

        # [b * s, h, head_size] --> [b, h, s, head_size]
        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]

        if self.softmax_in_single_precision:
            scale = self.head_size**-0.5
            k_transposed = rearrange(k, "b h s d -> b h d s")
            attn_weights = torch.matmul(q, k_transposed) * scale
            del k, k_transposed
            # masked-out positions get the most negative representable value
            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
            attn_weights = attn_weights + attention_mask
            del attention_mask
            # full-precision softmax
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(q.dtype)
            attn_weights = nn.functional.dropout(
                attn_weights, p=self.dropout, training=False
            )
            output = torch.matmul(attn_weights, v)
            del attn_weights, v
        else:
            # SDPA
            # [b, h, s, head_size]
            output = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=self.dropout,
                is_causal=False,
            )

        # [b, h, s, head_size] --> [b * s, h, head_size]
        output = rearrange(output, "b h s d -> (b s) h d")

        return output
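

# Small worked example of the cached patch mask above, not called anywhere in
# this module. The cu_seqlens value is made up for illustration: with
# cu_seqlens = (0, 2, 5) and s = 5, flatten_batch=True produces a [1, 5, 5]
# block-diagonal mask with a 2x2 and a 3x3 block, so patches attend only to
# patches from the same image.
def _example_patch_attention_mask() -> torch.Tensor:
    mask = VisionSdpaAttention._generate_mask_cache(5, True, (0, 2, 5))
    expected = torch.zeros(5, 5, dtype=torch.bool)
    expected[:2, :2] = True
    expected[2:, 2:] = True
    assert mask.shape == (1, 5, 5) and torch.equal(mask[0], expected)
    return mask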


class VisionTritonAttention(nn.Module):
    """
    Triton-implemented attention without a causal mask
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        _bsz: int,
        cu_seqlens: Optional[torch.Tensor],
        # Accepted for interface parity with VisionSdpaAttention; the Triton
        # kernel derives per-sequence boundaries from cu_seqlens, so the mask
        # is ignored here.
        _attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            cu_seqlens: cumulative sequence lengths, starting at 0
        Returns:
            [b * s, h, head_size]
        """

        # [b * s, head, head_size]
        output = torch.empty_like(q)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()
        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens.cuda(),
            seq_lens.cuda(),
            max_seqlen,
            is_causal=False,
        )

        return output
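

# Illustrative sketch of the cu_seqlens bookkeeping done by
# VisionTritonAttention.forward above, not called anywhere in this module: the
# per-sequence lengths and max_seqlen handed to context_attention_fwd are just
# consecutive differences of the cumulative offsets. The example values are
# made up.
def _example_triton_seqlen_bookkeeping() -> Tuple[torch.Tensor, int]:
    # Three images packed into 12 patch tokens: lengths 3, 4 and 5.
    cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    max_seqlen = int(seq_lens.max().item())
    assert seq_lens.tolist() == [3, 4, 5] and max_seqlen == 5
    return seq_lens, max_seqlen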