# sglang 0.4.5.post1: python/sglang/srt/lora/backend/base_backend.py
from typing import Tuple, Union

import torch

from sglang.srt.lora.utils import LoRABatchInfo


def get_fuse_output_add_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)
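

# For example, get_fuse_output_add_from_name("triton") returns True (the Triton
# lora_b kernel can fuse the residual add into the kernel), while
# get_fuse_output_add_from_name("flashinfer") returns False.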


class BaseLoRABackend:
    """Base class for different LoRA backends.

    Each backend provides its own implementation of the LoRA kernels.

    Args:
        name: name of the backend
        batch_info: information about the current batch, used by the kernels
        fuse_output_add: if True, the output buffer is passed into the lora_b
            forward pass and the residual add is fused into the kernel
        fuse_stacked_lora_b: if True, the backend consumes the lora_b weights of
            stacked modules (e.g., qkv_proj, gate_up_proj) as a single stacked
            tensor instead of a tuple of per-module tensors
    """

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        self.name = name
        self.batch_info = batch_info
        # Capability flags of the selected backend (see the lookup tables above).
        self.fuse_output_add = get_fuse_output_add_from_name(name)
        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run the segment GEMM of lora_a modules with the current backend.

        See https://docs.flashinfer.ai/api/gemm.html for the definition of segment GEMM.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
                where r is the LoRA rank and c is a multiplier for stacked modules
                (e.g., c=3 for qkv_proj, c=2 for gate_up_proj);
                usually input_dim is much larger than r
        Returns:
            result with shape (s, c * r)
        """
        pass
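
    # Shape walk-through (hypothetical sizes): with s=7 tokens, input_dim=4096,
    # rank r=16 and c=3 (stacked qkv), x is (7, 4096), weights is
    # (num_lora, 48, 4096), and the result is (7, 48).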

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run the segment GEMM of lora_b modules with the current backend.

        See https://docs.flashinfer.ai/api/gemm.html for the definition of segment GEMM.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all sequence lengths
                and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r
        Returns:
            result with shape (s, output_dim)
        """
        pass
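
    # Shape walk-through (hypothetical sizes): with s=7, r=16 and output_dim=4096,
    # x is (7, 16), weights is (num_lora, 4096, 16), and the result is (7, 4096).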

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
        *args,
        **kwargs
    ) -> torch.Tensor:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be
                (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it should contain
                a lora_b module for q with shape (1, num_lora, output_dim_q, r)
                and a combined lora_b module for kv with shape (2, num_lora, output_dim_kv, r).
        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass
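
    # Shape walk-through (hypothetical sizes): with r=16, output_dim_q=4096 and
    # output_dim_kv=1024, the stacked-tensor form of qkv_lora_b is
    # (num_lora, 4096 + 2 * 1024, 16) and the result is (s, 6144).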

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
        *args,
        **kwargs
    ) -> torch.Tensor:
        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of all sequence lengths
            gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
            gate_up_lora_b: lora_b module for gate_up_proj.
                If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r).
                If passed in as a tuple, it should contain two tensors with shape
                (num_lora, output_dim, r).
        Returns:
            result with shape (s, 2 * output_dim)
        """
        pass
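
    # Shape walk-through (hypothetical sizes): with r=16 and output_dim=11008
    # (the MLP intermediate size), the stacked-tensor form of gate_up_lora_b is
    # (num_lora, 22016, 16) and the result is (s, 22016).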

    def set_batch_info(self, batch_info: LoRABatchInfo):
        self.batch_info = batch_info
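

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of sglang: a naive pure-PyTorch segment
# GEMM that spells out the contract the backend methods above implement.
# `seg_indptr` and `weight_indices` are hypothetical stand-ins for the
# segment/adapter bookkeeping carried by LoRABatchInfo; the concrete backends
# are expected to fuse this loop into Triton or FlashInfer kernels.
# ---------------------------------------------------------------------------
def _naive_segment_gemm(
    x: torch.Tensor,  # (s, k): rows of all sequences concatenated
    weights: torch.Tensor,  # (num_lora, n, k): one weight matrix per adapter
    seg_indptr: torch.Tensor,  # (num_seq + 1,): row offsets of each sequence in x
    weight_indices: torch.Tensor,  # (num_seq,): adapter index used by each sequence
) -> torch.Tensor:
    """Compute y[i] = x[i] @ weights[adapter_of_row_i].T, one segment at a time."""
    s = x.shape[0]
    n = weights.shape[1]
    y = torch.zeros((s, n), dtype=x.dtype, device=x.device)
    for seq in range(weight_indices.shape[0]):
        start, end = int(seg_indptr[seq]), int(seg_indptr[seq + 1])
        w = weights[int(weight_indices[seq])]  # (n, k)
        y[start:end] = x[start:end] @ w.T
    return y


if __name__ == "__main__":
    # Shape demo with made-up sizes: two sequences (lengths 3 and 4), two
    # adapters, input_dim=8, rank r=2, output_dim=6.
    s, input_dim, r, output_dim, num_lora = 7, 8, 2, 6, 2
    x = torch.randn(s, input_dim)
    lora_a = torch.randn(num_lora, r, input_dim)
    lora_b = torch.randn(num_lora, output_dim, r)
    seg_indptr = torch.tensor([0, 3, 7])
    weight_indices = torch.tensor([0, 1])

    h = _naive_segment_gemm(x, lora_a, seg_indptr, weight_indices)  # (7, 2)
    y = _naive_segment_gemm(h, lora_b, seg_indptr, weight_indices)  # (7, 6)
    assert h.shape == (s, r) and y.shape == (s, output_dim)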