# sglang0.4.5.post1/python/sglang/srt/lora/layers.py

from typing import List, Tuple
import torch
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module):
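    """Base class that wraps a module together with a LoRA backend.

    The default implementation simply delegates ``forward`` to the wrapped base
    layer; subclasses override ``set_lora_info`` and the ``slice_lora_*_weights``
    helpers to enable the actual LoRA computation.
    """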
def __init__(
self,
base_layer: nn.Module,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.base_layer: nn.Module = base_layer
self.set_lora: bool = False
self.lora_backend: BaseLoRABackend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
pass
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
"""
Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
Note: The current version does not yet implement the LoRA functionality.
This class behaves exactly the same as the base VocabParallelEmbedding.
Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
"""
def __init__(
self,
base_layer: VocabParallelEmbedding,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
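    """LoRA wrapper for ``ColumnParallelLinear``.

    The LoRA delta is applied to the tensor-parallel partial output before the
    optional all-gather, so lora_B weights are sharded along the output
    dimension while lora_A stays replicated across ranks.
    """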
def __init__(
self,
base_layer: ColumnParallelLinear,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
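        # Shrink with lora_A, then expand with lora_B; backends that set
        # fuse_output_add fold the addition of base_output into the lora_B GEMM.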
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
)
def forward(self, input_: torch.Tensor):
        # Duplicates ColumnParallelLinear.forward so the LoRA delta can be
        # applied to the partial output before the optional all-gather.
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
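        # lora_B is sharded along its output dimension to match this rank's partition.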
shard_size = self.base_layer.output_partition_sizes[0]
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
B = B[start_idx:end_idx, :]
return B
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
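    """LoRA wrapper for the merged gate/up projection (``MergedColumnParallelLinear``).

    A single lora_A buffer is shared by the gate and up projections; when the
    backend supports it, the two lora_B buffers are stacked so one kernel can
    produce both outputs.
    """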
def __init__(
self,
base_layer: MergedColumnParallelLinear,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
            # TODO: avoid calling contiguous() on the GPU.
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
).contiguous()
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_output = self.lora_backend.run_gate_up_lora(
x,
self.A_buffer_gate_up,
self.B_buffer_gate_up,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
)
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        # Gate and up share the same output partition size, so take it from either one.
shard_size = self.base_layer.output_partition_sizes[0]
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
return B[:, start_idx:end_idx, :]
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
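    """LoRA wrapper for the fused QKV projection (``QKVParallelLinear``).

    q, k and v may have different output sizes, so when the backend fuses the
    stacked lora_B weights, per-projection output offsets are also recorded for
    the kernel.
    """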
    def __init__(
self,
base_layer: QKVParallelLinear,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
def set_lora_info(
self,
A_buffer_qkv: torch.Tensor,
B_buffer_q: torch.Tensor,
B_buffer_kv: torch.Tensor,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
            # Used to compute the number of launched blocks.
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
)
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A
def slice_lora_b_weights(
self, B: List[torch.Tensor], tp_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
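        # q lora_B is sharded by tp_rank, while k/v lora_B is sharded by the kv
        # shard id, which accounts for KV-head replication when tp_size exceeds
        # the number of KV heads.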
B_q, B_kv = B
base_layer = self.base_layer
q_proj_shard_size = base_layer.q_proj_shard_size
kv_proj_shard_size = base_layer.kv_proj_shard_size
num_kv_head_replicas = base_layer.num_kv_head_replicas
q_start_idx = q_proj_shard_size * tp_rank
q_end_idx = q_start_idx + q_proj_shard_size
kv_shard_id = tp_rank // num_kv_head_replicas
kv_start_idx = kv_proj_shard_size * kv_shard_id
kv_end_idx = kv_start_idx + kv_proj_shard_size
return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
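    """LoRA wrapper for ``RowParallelLinear``.

    The input is sharded along its last dimension across ranks, so lora_A is
    sliced along its input dimension while lora_B stays replicated; the LoRA
    delta is added to the partial output before the tensor-parallel all-reduce.
    """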
def __init__(
self,
base_layer: RowParallelLinear,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
)
def forward(self, input_: torch.Tensor):
        # Duplicates RowParallelLinear.forward so the LoRA delta can be applied
        # to the partial output before the all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
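        # lora_A is sharded along its input dimension to match this rank's input partition.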
shard_size = self.base_layer.input_size_per_partition
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
A = A[:, start_idx:end_idx].contiguous()
return A
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
return B
def get_lora_layer(
layer: nn.Module, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
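    """Wrap a supported base layer with its corresponding LoRA-enabled variant."""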
supported_layer_types = {
        # The order matters: subclasses such as QKVParallelLinear and
        # MergedColumnParallelLinear must be checked before their parent
        # ColumnParallelLinear, since isinstance also matches base classes.
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):
ret = lora_layer_type(layer, lora_backend)
return ret
    raise ValueError(f"No corresponding LoRA layer supported for {type(layer)}.")
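

# Minimal usage sketch (illustrative only): `base_linear`, `backend`, `A_buf`,
# `B_buf`, and `hidden_states` are hypothetical objects that in practice come
# from the model definition and the LoRA manager.
#
#   lora_layer = get_lora_layer(base_linear, backend)
#   lora_layer.set_lora_info(A_buf, B_buf)              # register LoRA weight buffers
#   output, output_bias = lora_layer.forward(hidden_states)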