sglang0.4.5.post1/python/sglang/srt/lora/lora.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import logging
import re
from typing import Dict, List

import torch
from torch import nn

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader

logger = logging.getLogger(__name__)


class LoRALayer(nn.Module):
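    """Container for the LoRA weights of a single transformer layer.

    The tensors are loaded from the adapter checkpoint and kept on CPU.
    """
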
    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
        super().__init__()
        self.config: LoRAConfig = config
        self.base_hf_config: AutoConfig = base_hf_config

        # LoRA weights kept on CPU; they are loaded from the adapter checkpoint.
        self.weights: Dict[str, torch.Tensor] = {}


class LoRAAdapter(nn.Module):
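    """One LoRA adapter (identified by ``uid``) loaded for a given base model.

    ``initialize_weights`` reads the PEFT checkpoint from ``config.path`` onto CPU,
    routes per-layer tensors into ``LoRALayer`` objects, and stacks q/k/v and
    gate/up LoRA weights so they line up with the fused qkv_proj / gate_up_proj
    modules of the base model. ``scaling`` follows the standard LoRA convention
    of ``lora_alpha / r``.
    """
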
    def __init__(
        self,
        uid: str,
        config: LoRAConfig,
        base_hf_config: AutoConfig,
        load_config: LoadConfig,
        lora_backend: BaseLoRABackend,
    ):
        super().__init__()
        self.uid: str = uid
        self.config: LoRAConfig = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config: AutoConfig = base_hf_config
        self.load_config: LoadConfig = load_config
        self.lora_backend: BaseLoRABackend = lora_backend
        self.scaling: float = self.config.lora_alpha / self.config.r

        self.layers: List[LoRALayer] = nn.ModuleList(
            [
                LoRALayer(config, base_hf_config)
                for i in range(base_hf_config.num_hidden_layers)
            ]
        )

        self.weights: Dict[str, torch.Tensor] = {}

    # initialize the LoRA weights to cpu
    def initialize_weights(self):
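        """Load the adapter checkpoint from ``self.config.path`` onto CPU.

        Weight names containing ``layers.<i>.`` are routed to the matching
        ``LoRALayer``; all other tensors stay in ``self.weights``. The per-layer
        q/k/v and gate/up LoRA weights are then stacked to match the fused
        qkv_proj and gate_up_proj modules of the base model.
        """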
        model_path = self.config.path
        loader = DefaultModelLoader(self.load_config)
        revision = getattr(self.config.hf_config, "revision", None)
        for name, loaded_weight in loader._get_weights_iterator(
            DefaultModelLoader.Source(
                model_path, revision=revision, fall_back_to_pt=True
            )
        ):
            match = re.search(r"layers\.(\d+)\.", name)
            if match is not None:
                layer_id = int(match.group(1))
                self.layers[layer_id].weights[name] = loaded_weight.cpu()
            else:
                self.weights[name] = loaded_weight.cpu()

        # stack kv_proj and gate_up_proj
        for i in range(self.base_hf_config.num_hidden_layers):
            layer = self.layers[i]
            # Snapshot the names first: the stacking helpers mutate layer.weights.
            weight_names = [name for name, _ in layer.weights.items()]
            self.stack_qkv_proj(weight_names, layer.weights)
            self.stack_gate_up_proj(weight_names, layer.weights)

    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
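        """Merge per-projection q/k/v LoRA weights into fused qkv_proj / kv_proj entries.

        lora_A weights are concatenated along dim 0 under a ``qkv_proj`` name;
        lora_B weights for k/v are stacked along a new leading dim under a
        ``kv_proj`` name, while the q_proj lora_B weight keeps its original name.
        A missing k_proj LoRA is replaced by zeros shaped like v_proj.
        """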
        # Collect the target q/k/v modules. This is necessary because an adapter
        # might not attach LoRA weights to every projection (e.g. k_proj).
        target_module = set()
        for weight_name in weight_names:
            if "k_proj" in weight_name:
                target_module.add("k_proj")
            if "q_proj" in weight_name:
                target_module.add("q_proj")
            if "v_proj" in weight_name:
                target_module.add("v_proj")
        if len(target_module) == 0:
            return

        for weight_name in weight_names:
            # We assume every LoRA adapter contains LoRA modules for q_proj.
            if "q_proj" in weight_name:
                q_name = weight_name
                k_name = weight_name.replace("q_proj", "k_proj")
                v_name = weight_name.replace("q_proj", "v_proj")
                kv_name = weight_name.replace("q_proj", "kv_proj")
                qkv_name = weight_name.replace("q_proj", "qkv_proj")

                # If k_proj doesn't have LoRA weights, initialize it to zero.
                k_proj_weight = (
                    weights[k_name]
                    if "k_proj" in target_module
                    else torch.zeros_like(weights[v_name])
                )
                if "lora_A" in weight_name:
                    weights[qkv_name] = torch.cat(
                        (
                            weights[q_name],
                            k_proj_weight,
                            weights[v_name],
                        ),
                        0,
                    )
                    weights.pop(q_name)
                    if "k_proj" in target_module:
                        weights.pop(k_name)
                    weights.pop(v_name)
                else:
                    weights[kv_name] = torch.stack(
                        [
                            k_proj_weight,
                            weights[v_name],
                        ],
                        dim=0,
                    )
                    if "k_proj" in target_module:
                        weights.pop(k_name)
                    weights.pop(v_name)

    def stack_gate_up_proj(
        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
    ):
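        """Merge gate_proj / up_proj LoRA weights into fused gate_up_proj entries.

        lora_A weights are concatenated along dim 0; lora_B weights are stacked
        along a new leading dim. A missing up_proj LoRA is filled with zeros,
        which is currently only supported on the triton backend.
        """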
        for weight_name in weight_names:
            if "gate_proj" in weight_name:
                up_name = weight_name.replace("gate_proj", "up_proj")
                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                if up_name not in weights:
                    logger.warning(
                        f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
                        f"Initializing the up projection to zero."
                    )
                    weights[up_name] = torch.zeros_like(weights[weight_name])
                    # FIXME: Add gate-only support for flashinfer in future implementations
                    assert self.lora_backend.name == "triton", (
                        f"LoRA weight initialization is currently only supported for the 'triton' backend. "
                        f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                        f"or consider implementing custom initialization logic for other backends."
                    )
                if "lora_A" in weight_name:
                    weights[gate_up_name] = torch.cat(
                        (weights[weight_name], weights[up_name]), 0
                    )
                else:
                    weights[gate_up_name] = torch.stack(
                        [weights[weight_name], weights[up_name]], dim=0
                    )
                weights.pop(weight_name)
                if up_name in weights:
                    weights.pop(up_name)
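

# A minimal usage sketch (assumptions: the uid and the config/backend objects below are
# hypothetical placeholders; in SGLang this class is normally constructed and driven by
# the LoRA manager rather than instantiated by hand):
#
#   adapter = LoRAAdapter(
#       uid="my-lora",                      # hypothetical adapter id
#       config=lora_config,                 # a LoRAConfig for the adapter checkpoint
#       base_hf_config=base_hf_config,      # AutoConfig of the base model
#       load_config=load_config,            # LoadConfig used by DefaultModelLoader
#       lora_backend=lora_backend,          # a BaseLoRABackend instance (e.g. triton)
#   )
#   adapter.initialize_weights()
#   # If the adapter targets the q/k/v and gate/up projections, adapter.layers[0].weights
#   # now holds fused entries named with "qkv_proj" / "kv_proj" and "gate_up_proj"
#   # instead of the original per-projection names.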