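"""FlashInfer-based LoRA backend.

Dispatches batched LoRA A/B projections to flashinfer's SegmentGEMMWrapper,
using the segment layout and adapter indices carried by LoRABatchInfo.
"""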
from typing import Tuple

import torch

from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    from flashinfer import SegmentGEMMWrapper


class FlashInferLoRABackend(BaseLoRABackend):
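    """LoRA backend that runs LoRA computations as flashinfer segment GEMMs.

    Each call reads bs, seg_indptr, weight_indices, and scalings from
    self.batch_info to apply the right adapter weights to each segment of the
    flattened input.
    """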

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        super().__init__(name, batch_info)

        # Set up the segment GEMM wrapper from flashinfer
        # FIXME wait for flashinfer segment gemm update
        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
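        """Run the LoRA A projection for every segment of the flattened batch.

        `x` is the flattened token batch and `weights` holds the stacked LoRA A
        matrices; flashinfer's segment GEMM uses batch_info.seg_indptr and
        batch_info.weight_indices to pick the weights applied to each segment.
        """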
        return self.segment_gemm.run(
            x=x,
            weights=weights,
            batch_size=self.batch_info.bs,
            weight_column_major=True,
            seg_indptr=self.batch_info.seg_indptr,
            weight_indices=self.batch_info.weight_indices,
        )

    def run_lora_b_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
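        """Run the LoRA B projection and apply the LoRA scaling factor.

        Same segment GEMM as run_lora_a_sgemm, with the result multiplied by
        batch_info.scalings[0], so a single scaling value is assumed for the
        whole batch.
        """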
        return (
            self.segment_gemm.run(
                x=x,
                weights=weights,
                batch_size=self.batch_info.bs,
                weight_column_major=True,
                seg_indptr=self.batch_info.seg_indptr,
                weight_indices=self.batch_info.weight_indices,
            )
            * self.batch_info.scalings[0]
        )

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:
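        """Fused LoRA pass for the packed q/k/v projection.

        One LoRA A segment GEMM produces an (s, 3 * r) intermediate, then three
        LoRA B segment GEMMs write the q, k, and v outputs into contiguous
        slices of a single (s, output_dim_q + 2 * output_dim_kv) tensor.

        `qkv_lora_b` is expected to be a tuple (q_lora_b, kv_lora_b).
        """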
        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2

        # Shape of lora_a_output: (s, 3 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

        q_lora_b, kv_lora_b = qkv_lora_b
        lora_rank = kv_lora_b.shape[-1]
        output_dim_q = q_lora_b.shape[-2]
        output_dim_kv = kv_lora_b.shape[-2]
        lora_output = torch.empty(
            (x.shape[0], output_dim_q + 2 * output_dim_kv),
            device=x.device,
            dtype=x.dtype,
        )

        # q projection
        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
        )

        # k projection (first of the two stacked kv LoRA B weights)
        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
            self.run_lora_b_sgemm(
                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
                weights=kv_lora_b[0],
            )
        )

        # v projection (second of the two stacked kv LoRA B weights)
        lora_output[
            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
        ] = self.run_lora_b_sgemm(
            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
            weights=kv_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: Tuple[torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:
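        """Fused LoRA pass for the packed gate/up projection.

        One LoRA A segment GEMM produces an (s, 2 * r) intermediate, then two
        LoRA B segment GEMMs fill the gate and up halves of a single
        (s, 2 * output_dim) tensor, scaled by batch_info.scalings[0].

        `gate_up_lora_b` is expected to be a tuple (gate_lora_b, up_lora_b).
        """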
        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
        lora_rank = gate_up_lora_b[0].shape[-1]
        output_dim = gate_up_lora_b[0].shape[-2]

        # Shape of lora_a_output: (s, 2 * r)
        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)

        lora_output = torch.empty(
            (x.shape[0], 2 * output_dim),
            device=x.device,
            dtype=x.dtype,
        )

        # Compute LoRA for the gate and up projections respectively
        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
            x=lora_a_output[:, :lora_rank].contiguous(),
            weights=gate_up_lora_b[0],
        )

        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
            x=lora_a_output[:, lora_rank:].contiguous(),
            weights=gate_up_lora_b[1],
        )

        return lora_output * self.batch_info.scalings[0]