import torch

from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
    gate_up_lora_b_fwd,
    qkv_lora_b_fwd,
    sgemm_lora_a_fwd,
    sgemm_lora_b_fwd,
)
from sglang.srt.lora.utils import LoRABatchInfo

class TritonLoRABackend(BaseLoRABackend):
    """LoRA backend that runs the shrink/expand GEMMs with the fused Triton kernels."""

    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
        super().__init__(name, batch_info)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        # Shrink GEMM: project x onto the LoRA rank with the A weights.
        return sgemm_lora_a_fwd(x, weights, self.batch_info)

    def run_lora_b_sgemm(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        base_output: torch.Tensor = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        # Expand GEMM: project the rank-r intermediate back to the output
        # dimension with the B weights, adding onto base_output when provided.
        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        max_qkv_out_dim: int,
        base_output: torch.Tensor = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        # x: (s, input_dim)
        # qkv_lora_a: (num_lora, 3 * r, input_dim)
        # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
        assert isinstance(qkv_lora_b, torch.Tensor)

        # One fused shrink over the stacked q/k/v A weights (stack_num=3), then a
        # single expand kernel that writes each projection to its output slice
        # using output_offset.
        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
        lora_output = qkv_lora_b_fwd(
            lora_a_output,
            qkv_lora_b,
            self.batch_info,
            output_offset,
            max_qkv_out_dim,
            base_output,
        )
        return lora_output

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: torch.Tensor,
        base_output: torch.Tensor = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        # x: (s, input_dim)
        # gate_up_lora_a: (num_lora, 2 * r, input_dim)
        # gate_up_lora_b: (num_lora, 2 * output_dim, r)
        assert isinstance(gate_up_lora_b, torch.Tensor)
        output_dim = gate_up_lora_b.shape[-2] // 2

        # One fused shrink over the stacked gate/up A weights (stack_num=2),
        # then one expand kernel that fills both output_dim-sized halves.
        # lora_a_output: (s, 2 * r)
        lora_a_output = sgemm_lora_a_fwd(
            x, gate_up_lora_a, self.batch_info, stack_num=2
        )
        lora_output = gate_up_lora_b_fwd(
            lora_a_output,
            gate_up_lora_b,
            self.batch_info,
            output_dim,
            base_output,
        )
        return lora_output
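

# A minimal plain-PyTorch sketch of the shrink/expand shapes that run_qkv_lora
# works with, assuming a single adapter and ignoring the per-request batching
# carried by LoRABatchInfo. The q/k/v slot layout (q rows first, then k, then v)
# is an assumption for illustration only; the real path uses the fused Triton
# kernels above.
if __name__ == "__main__":
    s, input_dim, r = 4, 16, 8
    output_dim_q, output_dim_kv = 16, 8

    x = torch.randn(s, input_dim)
    qkv_lora_a = torch.randn(1, 3 * r, input_dim)
    qkv_lora_b = torch.randn(1, output_dim_q + 2 * output_dim_kv, r)

    # Shrink: (s, input_dim) @ (input_dim, 3 * r) -> (s, 3 * r)
    lora_a_output = x @ qkv_lora_a[0].T

    # Expand each projection from its rank-r slice of the intermediate.
    q = lora_a_output[:, :r] @ qkv_lora_b[0, :output_dim_q].T
    k = lora_a_output[:, r:2 * r] @ qkv_lora_b[0, output_dim_q:output_dim_q + output_dim_kv].T
    v = lora_a_output[:, 2 * r:] @ qkv_lora_b[0, output_dim_q + output_dim_kv:].T
    print(q.shape, k.shape, v.shape)  # (4, 16), (4, 8), (4, 8)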