from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union

import torch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
            )
        else:
            return self.forward_extend(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
            )

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run a forward for extend."""
        raise NotImplementedError()
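
# The sketch below is illustrative only, not part of the library: a minimal
# concrete backend showing how `AttentionBackend.forward` dispatches to
# `forward_decode` / `forward_extend` based on the batch's forward mode. It
# assumes q/k/v arrive shaped (..., seq_len, head_dim) and that `layer.scaling`
# holds the softmax scale; real backends additionally read and write the KV
# cache through `forward_batch` and `layer`, honoring `save_kv_cache`.
class _NaiveAttentionBackendSketch(AttentionBackend):
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # This toy backend needs no per-batch metadata.
        pass

    def get_cuda_graph_seq_len_fill_value(self):
        # Padded sequence lengths are filled with 0 for this sketch.
        return 0

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Plain scaled dot-product attention over the tokens in this batch.
        scores = torch.softmax(q @ k.transpose(-2, -1) * layer.scaling, dim=-1)
        return scores @ v

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # A real decode path would attend against the cached keys/values for
        # each request; the sketch just reuses the same math as extend.
        return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)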