from typing import Tuple, Union

import torch

from sglang.srt.lora.utils import LoRABatchInfo


def get_fuse_output_add_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)


def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
    mapping = {
        "triton": True,
        "flashinfer": False,
    }
    return mapping.get(name, False)
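

# Illustrative usage (a minimal sketch, not part of the API): both helpers
# fall back to the unfused path for backend names they do not know about.
#
#   assert get_fuse_output_add_from_name("triton") is True
#   assert get_fuse_output_add_from_name("some_new_backend") is False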


class BaseLoRABackend:
    """Base class for the different LoRA backends.

    Each backend has its own implementation of the LoRA kernels.

    Args:
        name: name of the backend
        batch_info: information about the current batch for the kernels to use

    Attributes:
        fuse_output_add: if True, the output buffer for storing the result is
            passed in during the lora_b forward, and the add operation is
            fused into the kernel (derived from the backend name)
        fuse_stacked_lora_b: if True, the backend consumes stacked lora_b
            weights for merged modules such as qkv_proj and gate_up_proj
            (derived from the backend name)
    """

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        self.name = name
        self.batch_info = batch_info
        self.fuse_output_add = get_fuse_output_add_from_name(name)
        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run the segment GEMM of lora_a modules with the current backend.

        See https://docs.flashinfer.ai/api/gemm.html for the definition of
        segment GEMM.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of
                all sequence lengths in the batch
            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
                where r is the LoRA rank and c is a multiplier for stacked
                modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj);
                usually input_dim is much larger than r
        Returns:
            result with shape (s, c * r)
        """
        pass
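
    # A minimal pure-PyTorch sketch of the segment GEMM above, for reference
    # only (real backends dispatch to fused kernels). It assumes `batch_info`
    # exposes `bs`, `seg_indptr` (cumulative segment offsets, length bs + 1)
    # and `weight_indices` (adapter index per segment); treat these field
    # names as assumptions about LoRABatchInfo, not guarantees.
    #
    #   def naive_lora_a_sgemm(
    #       x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
    #   ) -> torch.Tensor:
    #       out = x.new_empty(x.shape[0], weights.shape[1])
    #       for i in range(batch_info.bs):
    #           start, end = batch_info.seg_indptr[i], batch_info.seg_indptr[i + 1]
    #           w = weights[batch_info.weight_indices[i]]  # (c * r, input_dim)
    #           out[start:end] = x[start:end] @ w.T
    #       return out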

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run the segment GEMM of lora_b modules with the current backend.

        See https://docs.flashinfer.ai/api/gemm.html for the definition of
        segment GEMM.

        Args:
            x: input matrix with shape (s, r), where s is the sum of all
                sequence lengths in the batch and r is the LoRA rank
            weights: a set of LoRA weights with shape (num_lora, output_dim, r);
                usually output_dim is much larger than r
        Returns:
            result with shape (s, output_dim)
        """
        pass
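
    # How fuse_output_add changes the calling convention (a hedged sketch;
    # the keyword name `base_output` is an assumption for illustration, not
    # a documented signature):
    #
    #   if backend.fuse_output_add:
    #       # the output buffer is passed in; the add is fused into the kernel
    #       output = backend.run_lora_b_sgemm(x, weights, base_output=output)
    #   else:
    #       output = output + backend.run_lora_b_sgemm(x, weights)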

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
        *args,
        **kwargs
    ) -> torch.Tensor:
        """Run the LoRA pass for the QKV layer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of
                all sequence lengths in the batch
            qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
            qkv_lora_b: lora_b module for qkv.
                If passed in as a tensor, its shape should be
                (num_lora, output_dim_q + 2 * output_dim_kv, r).
                If passed in as a tuple of two tensors, it should contain
                a lora_b module for q, with shape (1, num_lora, output_dim_q, r),
                and a combined lora_b module for kv, with shape
                (2, num_lora, output_dim_kv, r).
        Returns:
            result with shape (s, output_dim_q + 2 * output_dim_kv)
        """
        pass
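
    # Illustrative decomposition for the tuple form of qkv_lora_b (a sketch
    # of the shape contract above, not any backend's actual kernel path);
    # here qkv_lora_b == (q_lora_b, kv_lora_b) and r is the LoRA rank:
    #
    #   lora_a_out = self.run_lora_a_sgemm(x, qkv_lora_a)                  # (s, 3 * r)
    #   q = self.run_lora_b_sgemm(lora_a_out[:, :r], q_lora_b[0])          # (s, output_dim_q)
    #   k = self.run_lora_b_sgemm(lora_a_out[:, r : 2 * r], kv_lora_b[0])  # (s, output_dim_kv)
    #   v = self.run_lora_b_sgemm(lora_a_out[:, 2 * r :], kv_lora_b[1])    # (s, output_dim_kv)
    #   return torch.cat([q, k, v], dim=-1)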

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
        *args,
        **kwargs
    ) -> torch.Tensor:
        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.

        Args:
            x: input matrix with shape (s, input_dim), where s is the sum of
                all sequence lengths in the batch
            gate_up_lora_a: lora_a module for gate_up_proj, with shape
                (num_lora, 2 * r, input_dim)
            gate_up_lora_b: lora_b module for gate_up_proj.
                If passed in as a tensor, its shape should be
                (num_lora, 2 * output_dim, r).
                If passed in as a tuple, it should contain two tensors with
                shape (num_lora, output_dim, r).
        Returns:
            result with shape (s, 2 * output_dim)
        """
        pass
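
    # The gate_up pass mirrors the qkv sketch above with c=2 (again a sketch
    # of the shape contract, assuming the tuple form of gate_up_lora_b):
    #
    #   lora_a_out = self.run_lora_a_sgemm(x, gate_up_lora_a)               # (s, 2 * r)
    #   gate = self.run_lora_b_sgemm(lora_a_out[:, :r], gate_up_lora_b[0])  # (s, output_dim)
    #   up = self.run_lora_b_sgemm(lora_a_out[:, r:], gate_up_lora_b[1])    # (s, output_dim)
    #   return torch.cat([gate, up], dim=-1)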

    def set_batch_info(self, batch_info: LoRABatchInfo):
        self.batch_info = batch_info
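

# Typical lifecycle (a minimal sketch; `MyLoRABackend` and
# `build_lora_batch_info` are hypothetical stand-ins for a concrete subclass
# and whatever code prepares the LoRABatchInfo):
#
#   backend = MyLoRABackend("my_backend")
#   for batch in batches:
#       backend.set_batch_info(build_lora_batch_info(batch))
#       lora_a_out = backend.run_lora_a_sgemm(x, lora_a_weights)
#       output = output + backend.run_lora_b_sgemm(lora_a_out, lora_b_weights)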