sglang0.4.5.post1/python/sglang/srt/lora/backend/flashinfer_backend.py

from typing import Tuple

import torch

from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    from flashinfer import SegmentGEMMWrapper


class FlashInferLoRABackend(BaseLoRABackend):

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        super().__init__(name, batch_info)

        # Set up SGemm Wrapper from flashinfer
        # FIXME wait for flashinfer segment gemm update
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # x: (s, input_dim); weights: (num_lora, out_features, input_dim), stored
        # column-major. Each row of x is routed to its sequence's adapter via
        # seg_indptr and weight_indices. Returns (s, out_features).
        return self.segment_gemm.run(
            x=x,
            weights=weights,
            batch_size=self.batch_info.bs,
            weight_column_major=True,
            seg_indptr=self.batch_info.seg_indptr,
            weight_indices=self.batch_info.weight_indices,
        )

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # x: (s, r); weights: (num_lora, output_dim, r). Returns (s, output_dim).
        # Note: the scaling of the first adapter is applied to the whole batch.
        return (
            self.segment_gemm.run(
                x=x,
                weights=weights,
                batch_size=self.batch_info.bs,
                weight_column_major=True,
                seg_indptr=self.batch_info.seg_indptr,
                weight_indices=self.batch_info.weight_indices,
            )
            * self.batch_info.scalings[0]
        )

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:

        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2

        # Shape of lora_a_output: (s, 3 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

        q_lora_b, kv_lora_b = qkv_lora_b
        lora_rank = kv_lora_b.shape[-1]
        output_dim_q = q_lora_b.shape[-2]
        output_dim_kv = kv_lora_b.shape[-2]
        lora_output = torch.empty(
            (x.shape[0], output_dim_q + 2 * output_dim_kv),
            device=x.device,
            dtype=x.dtype,
        )

        # q
        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
        )

        # kv
        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
            self.run_lora_b_sgemm(
                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
                weights=kv_lora_b[0],
            )
        )

        lora_output[
            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
        ] = self.run_lora_b_sgemm(
            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
            weights=kv_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:

        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
        lora_rank = gate_up_lora_b[0].shape[-1]
        output_dim = gate_up_lora_b[0].shape[-2]

        # Shape of lora_a_output: (s, 2 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)

        lora_output = torch.empty(
            (x.shape[0], 2 * output_dim),
            device=x.device,
            dtype=x.dtype,
        )

        # Compute lora for gate and up proj respectively
        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(),
            weights=gate_up_lora_b[0],
        )

        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
            x=lora_a_output[:, lora_rank:].contiguous(),
            weights=gate_up_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]
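

# ---------------------------------------------------------------------------
# Illustrative reference (not part of the original file): a plain-torch,
# single-adapter sketch of the slicing that run_qkv_lora performs above.
# The fused lora_a projection yields (s, 3 * r); its q/k/v chunks are then
# projected by per-output lora_b weights of shape (output_dim, r). Scaling is
# simplified to a single multiply here, and the function name is hypothetical;
# this is only a shape reference, not the backend's numerics.
# ---------------------------------------------------------------------------
def _reference_qkv_lora_single_adapter(
    x: torch.Tensor,            # (s, input_dim)
    qkv_lora_a: torch.Tensor,   # (3 * r, input_dim)
    q_lora_b: torch.Tensor,     # (output_dim_q, r)
    k_lora_b: torch.Tensor,     # (output_dim_kv, r)
    v_lora_b: torch.Tensor,     # (output_dim_kv, r)
    scaling: float = 1.0,
) -> torch.Tensor:
    r = q_lora_b.shape[-1]
    lora_a_output = x @ qkv_lora_a.T              # (s, 3 * r)
    q = lora_a_output[:, :r] @ q_lora_b.T         # (s, output_dim_q)
    k = lora_a_output[:, r : 2 * r] @ k_lora_b.T  # (s, output_dim_kv)
    v = lora_a_output[:, 2 * r :] @ v_lora_b.T    # (s, output_dim_kv)
    # (s, output_dim_q + 2 * output_dim_kv), matching the layout of run_qkv_lora
    return torch.cat([q, k, v], dim=-1) * scaling


# Example shapes: x (6, 64), r = 8, output_dim_q = 64, output_dim_kv = 16
# -> result (6, 96).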