from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union

import torch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
            )
        else:
            return self.forward_extend(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
            )

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run a forward for extend."""
        raise NotImplementedError()
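

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of sglang): a hypothetical minimal backend
# showing how the interface above can be filled in. It assumes a recent PyTorch
# (for the `scale` kwarg of scaled_dot_product_attention), assumes `layer`
# exposes `tp_q_head_num`, `tp_k_head_num`, `head_dim`, and `scaling`, assumes
# the same number of query and key/value heads, and treats the whole extend
# batch as a single sequence with no cached prefix. A real backend would read
# and write the paged KV cache through `forward_batch` instead.
# ---------------------------------------------------------------------------
class NaiveSDPABackend(AttentionBackend):
    """A toy backend that runs causal attention over the current tokens only."""

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Nothing to precompute for this naive implementation.
        pass

    def get_cuda_graph_seq_len_fill_value(self):
        # Padded sequence lengths are not used here; 0 is a safe fill value.
        return 0

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        # A real backend would store k/v into the KV cache when save_kv_cache
        # is True; this sketch skips caching entirely.
        num_tokens = q.shape[0]
        # Reshape (num_tokens, num_heads * head_dim)
        # -> (1, num_heads, num_tokens, head_dim) as expected by SDPA.
        q_ = q.view(num_tokens, layer.tp_q_head_num, layer.head_dim).transpose(0, 1).unsqueeze(0)
        k_ = k.view(num_tokens, layer.tp_k_head_num, layer.head_dim).transpose(0, 1).unsqueeze(0)
        v_ = v.view(num_tokens, layer.tp_k_head_num, layer.head_dim).transpose(0, 1).unsqueeze(0)
        o = torch.nn.functional.scaled_dot_product_attention(
            q_, k_, v_, is_causal=True, scale=layer.scaling
        )
        # Back to (num_tokens, num_heads * head_dim), the layout callers expect.
        return o.squeeze(0).transpose(0, 1).reshape(num_tokens, -1)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        # Decode attends to tokens cached during extend, which this sketch does
        # not keep, so only the extend path is supported here.
        raise NotImplementedError("This sketch does not implement decode.")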