# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters" # and "Punica: Multi-Tenant LoRA Serving" import logging from typing import Dict, List, Set, Tuple import torch from sglang.srt.configs.load_config import LoadConfig from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( LoRABatchInfo, LoRAType, get_customized_names_from_hf_names, get_layer_id, get_stacked_name, get_weight_name, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import replace_submodule logger = logging.getLogger(__name__) class LoRAManager: def __init__( self, base_model: torch.nn.Module, lora_paths: Dict[str, str], base_hf_config: AutoConfig, max_loras_per_batch: int, load_config: LoadConfig, dtype: torch.dtype, lora_backend: str = "triton", tp_size: int = 1, tp_rank: int = 0, ): self.base_model: torch.nn.Module = base_model self.lora_paths: Dict[str, str] = lora_paths self.base_hf_config: AutoConfig = base_hf_config self.max_loras_per_batch: int = max_loras_per_batch self.load_config: LoadConfig = load_config self.dtype: torch.dtype = dtype self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") backend_type = get_backend_from_name(lora_backend) self.lora_backend: BaseLoRABackend = backend_type(lora_backend) self.init_loras() self.init_lora_memory_pool() def init_loras(self): # Config of each LoRA adapter self.configs: Dict[str, LoRAConfig] = {} # Target module names in huggingface lora configs. # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"} self.hf_target_names: Set[str] = set() for name, path in self.lora_paths.items(): self.configs[name] = LoRAConfig(path) self.hf_target_names.update(self.configs[name].target_modules) # Target lora weight names for lora_a and lora_b modules repectively. # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")} self.lora_weight_names: Set[Tuple[str]] = set( [get_stacked_name(module) for module in self.hf_target_names] ) # load all weights to cpu self.loras: Dict[str, LoRAAdapter] = {} for name in self.lora_paths.keys(): lora_adapter = LoRAAdapter( name, self.configs[name], self.base_hf_config, self.load_config, self.lora_backend, ) lora_adapter.initialize_weights() self.loras[name] = lora_adapter # misc lora configs self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()]) if self.lora_backend == "flashinfer": # FIXME remove the restrictions after supporting multi-rank for flashinfer backend max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()]) scaling = list(self.loras.values())[0].scaling assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values()) assert all(x.scaling == scaling for x in self.loras.values()) # Convert original model layers to layers with LoRA self.convert_to_lora_layers() def init_lora_memory_pool(self): # Initialize memory pool self.memory_pool = LoRAMemoryPool( self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype, self.tp_size, self.tp_rank, self.lora_modules, ) # Initialize target lora modules in memory pool self.memory_pool.init_buffers(self.lora_weight_names, self.base_model) def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch self.memory_pool.prepare_lora_batch(cur_uids, self.loras) # FIXME: Handle lora uid with None more safely if cur_uids == set([None]): return # set up batch info shared by all lora moruldes bs = forward_batch.batch_size seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() else torch.ones(bs, device=self.device) ) seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device) seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) lora_ranks = torch.empty( (self.max_loras_per_batch,), dtype=torch.int64, device="cuda" ) scalings = torch.empty( (self.max_loras_per_batch,), dtype=torch.float, device="cuda" ) for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) lora = self.loras[lora_path] lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] scalings[weight_indices[i]] = lora.scaling batch_info = LoRABatchInfo( bs=bs, seg_lens=seg_lens, seg_indptr=seg_indptr, max_len=max_len, weight_indices=weight_indices, lora_ranks=lora_ranks, scalings=scalings, ) self.lora_backend.set_batch_info(batch_info) # call set_lora_info for each lora modules for layer_id, modules in self.lora_modules.items(): for module_name, module in modules: if "qkv_proj" in module_name: module.set_lora_info( self.memory_pool.get_tensor( "qkv_proj", layer_id, LoRAType.LORA_A ), self.memory_pool.get_tensor( "q_proj", layer_id, LoRAType.LORA_B ), self.memory_pool.get_tensor( "kv_proj", layer_id, LoRAType.LORA_B ), ) else: weight_name = get_weight_name( module_name, self.lora_weight_names, LoRAType.LORA_A ) module.set_lora_info( self.memory_pool.get_tensor( weight_name, layer_id, LoRAType.LORA_A ), self.memory_pool.get_tensor( weight_name, layer_id, LoRAType.LORA_B ), ) def set_lora_module(self, module_name, module): lora_module = get_lora_layer(module, self.lora_backend) replace_submodule(self.base_model, module_name, lora_module) return lora_module def convert_to_lora_layers(self): # Target module names of customized layers defined in python/sglang/srt/layers # e.g., {"qkv_proj", "o_proj"} customized_target_names = get_customized_names_from_hf_names( self.hf_target_names, self.base_model ) # Monkey patch to use the LoRA version layers self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = { i: [] for i in range(self.base_hf_config.num_hidden_layers) } for module_name, module in self.base_model.named_modules(): # The module should be converted if it is included in target_names if module_name.split(".")[-1] in customized_target_names: layer_id = get_layer_id(module_name) self.lora_modules[layer_id].append( (module_name, self.set_lora_module(module_name, module)) )