# sglang0.4.5.post1/python/sglang/srt/lora/mem_pool.py


from typing import Dict, List, Optional, Set, Tuple

import torch

from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
    ROW_PARALLELISM_LINEAR_LORA_NAMES,
    LoRAType,
    get_hidden_dim,
    get_stacked_multiply,
    get_weight_name,
)


class LoRAMemoryPool:
    """Class for memory pool management of LoRA modules."""

    def __init__(
        self,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        max_lora_dim: int,
        dtype: torch.dtype,
        tp_size: int,
        tp_rank: int,
        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
    ):
        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
        self.max_loras_per_batch: int = max_loras_per_batch
        self.max_lora_dim: int = max_lora_dim
        self.dtype: torch.dtype = dtype
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

        # Both A_buffer and B_buffer map LoRA weight names to their buffer space.
        # A_buffer contains num_layer row-major tensors with shape
        #   (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
        # B_buffer contains num_layer column-major tensors with shape
        #   (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
        self.A_buffer: Dict[str, List[torch.Tensor]] = {}
        self.B_buffer: Dict[str, List[torch.Tensor]] = {}

        # LoRA uid -> buffer idx in the memory pool.
        self.uid_to_buffer_id: Dict[Optional[str], int] = {}

        # Buffer idx -> LoRA uid in the memory pool.
        # Empty slots are initialized to the empty string rather than None,
        # since None is itself a valid uid (it denotes "no adapter").
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

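    # Worked example (illustrative numbers, not from the original source): with
    # max_loras_per_batch = 4, max_lora_dim = 16, and a non-stacked module whose
    # input_dim = 4096 and output_dim = 11008 (stacked_num = 1), each layer holds
    #   A_buffer[name][layer_id] of shape (4, 1 * 16, 4096)
    #   B_buffer[name][layer_id] of shape (1, 4, 11008, 16)
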
    def get_lora_A_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int, ...]:
        """
        Given a module_name (which might be a stacked name), return the shape of
        the per-layer LoRA-A buffer for that module.
        """
        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
        c = get_stacked_multiply(module_name)
        if self.tp_size > 1:
            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
                input_dim = divide(input_dim, self.tp_size)
        return (
            self.max_loras_per_batch,
            self.max_lora_dim * c,
            input_dim,
        )

    def get_lora_B_shape(
        self, module_name: str, base_model: torch.nn.Module
    ) -> Tuple[int, ...]:
        """
        Given a module_name (which might be a stacked name), return the shape of
        the per-layer LoRA-B buffer for that module.
        """
        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
        c = get_stacked_multiply(module_name)
        if self.tp_size > 1:
            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
                output_dim = divide(output_dim, self.tp_size)
        return (
            c,
            self.max_loras_per_batch,
            output_dim,
            self.max_lora_dim,
        )

    def init_buffers(
        self,
        lora_weight_names: Set[Tuple[str]],
        base_model: torch.nn.Module,
    ):
        # lora_weight_names is a set of name pairs; each pair gives the A- and
        # B-side weight names of a LoRA module to load,
        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
        self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
        device = next(base_model.parameters()).device
        lora_module_A_names = {name[0] for name in lora_weight_names}
        lora_module_B_names = {name[1] for name in lora_weight_names}

        # Init A tensors, column_major=False
        for module_A in lora_module_A_names:
            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
            self.A_buffer[module_A] = [
                torch.empty(
                    lora_A_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for _ in range(self.num_layer)
            ]

        # Init B tensors, column_major=True
        for module_B in lora_module_B_names:
            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
            self.B_buffer[module_B] = [
                torch.empty(
                    lora_B_shape,
                    dtype=self.dtype,
                    device=device,
                )
                for _ in range(self.num_layer)
            ]

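    # With the example lora_weight_names above, A_buffer ends up keyed by
    # {"qkv_proj", "o_proj"} and B_buffer by {"q_proj", "kv_proj", "o_proj"},
    # each mapping to a list of num_layer tensors.
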
    def prepare_lora_batch(
        self,
        cur_uids: Set[Optional[str]],
        lora_adapters: Dict[str, LoRAAdapter],
    ):
        def get_available_buffer_slot():
            for buffer_id in range(self.max_loras_per_batch):
                # Prioritize empty slots
                if self.buffer_id_to_uid[buffer_id] == "":
                    return buffer_id, ""

            for buffer_id in range(self.max_loras_per_batch):
                # Evict a LoRA that is not needed by the current batch
                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
                    return buffer_id, self.buffer_id_to_uid[buffer_id]

            raise ValueError(
                "No available buffer slots found. Please ensure the number of "
                "active LoRAs is less than max_loras_per_batch."
            )

        for uid in cur_uids:
            if uid not in self.uid_to_buffer_id:
                buffer_id, evicted_lora_uid = get_available_buffer_slot()
                if evicted_lora_uid != "":
                    self.uid_to_buffer_id.pop(evicted_lora_uid)

                self.load_lora_weight_to_buffer(
                    uid, buffer_id, lora_adapters.get(uid, None)
                )
                self.uid_to_buffer_id[uid] = buffer_id
                self.buffer_id_to_uid[buffer_id] = uid

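    # Example (hypothetical trace): with max_loras_per_batch == 2 and the two
    # slots currently holding "adapter_a" and "adapter_b", calling
    # prepare_lora_batch({"adapter_a", "adapter_c"}, lora_adapters) keeps
    # "adapter_a" in place, evicts "adapter_b" (it is not in cur_uids), and
    # loads "adapter_c" into the freed slot.
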
    def load_lora_weight_to_buffer(
        self,
        uid: Optional[str],
        buffer_id: int,
        lora_adapter: Optional[LoRAAdapter] = None,
    ):
        if uid is None:
            # A None uid means the slot serves requests without an adapter:
            # zero the LoRA-A rows so the slot contributes nothing.
            for i in range(self.num_layer):
                for k in self.A_buffer.keys():
                    self.A_buffer[k][i][buffer_id] = 0
            return

        assert lora_adapter is not None
        lora_rank = lora_adapter.config.hf_config["r"]
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            temp_A_buffer: Dict[str, torch.Tensor] = {}
            temp_B_buffer: Dict[str, torch.Tensor] = {}

            # Group this layer's adapter weights by their (stacked) weight names.
            for name, weights in layer_weights.items():
                if "lora_A" in name:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_A
                    )
                    temp_A_buffer[lora_weight_name] = weights
                else:
                    lora_weight_name = get_weight_name(
                        name, self.lora_weight_names, LoRAType.LORA_B
                    )
                    temp_B_buffer[lora_weight_name] = weights

            if self.tp_size > 1:
                # Under tensor parallelism, keep only this rank's shard of each weight.
                cur_layer_modules = self.lora_modules[layer_id]
                for module_name, module in cur_layer_modules:
                    if "qkv_proj" in module_name:
                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                            temp_A_buffer["qkv_proj"], self.tp_rank
                        )
                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
                            module.slice_lora_b_weights(
                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
                                self.tp_rank,
                            )
                        )
                    else:
                        weight_name = get_weight_name(
                            module_name, self.lora_weight_names, LoRAType.LORA_A
                        )
                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                            temp_A_buffer[weight_name], self.tp_rank
                        )
                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
                            temp_B_buffer[weight_name], self.tp_rank
                        )

            # Copy into the pool: only the first lora_rank rows/columns of each
            # slot are written, since the adapter's rank may be smaller than
            # max_lora_dim.
            for name, weights in temp_A_buffer.items():
                c = get_stacked_multiply(name)
                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
                    weights
                )

            for name, weights in temp_B_buffer.items():
                c = get_stacked_multiply(name)
                if c > 1:
                    for stacked_id in range(c):
                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
                            :, :lora_rank
                        ].copy_(weights[stacked_id])
                else:
                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
                        weights
                    )

    def get_tensor(
        self, weight_name: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
        if lora_type == LoRAType.LORA_A:
            return self.A_buffer[weight_name][layer_id]
        return self.B_buffer[weight_name][layer_id]

    def get_buffer_id(self, lora_uid: str):
        return self.uid_to_buffer_id[lora_uid]
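

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how tensors laid out
# in the documented buffer shapes would feed a plain-PyTorch LoRA delta
# computation. All sizes and the slot index below are made up, and the real
# serving path goes through sglang's LoRA kernels rather than this loop; any
# per-adapter scaling is assumed to be folded into the weights already.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    max_loras_per_batch, max_lora_dim = 4, 16
    input_dim, output_dim, lora_rank = 64, 96, 8
    buffer_id = 2  # hypothetical slot assigned by prepare_lora_batch

    # One layer's buffers for a non-stacked module (stacked_num == 1),
    # matching the shapes documented in LoRAMemoryPool.__init__.
    A_layer = torch.zeros(max_loras_per_batch, max_lora_dim, input_dim)
    B_layer = torch.zeros(1, max_loras_per_batch, output_dim, max_lora_dim)

    # Pretend load_lora_weight_to_buffer copied a rank-8 adapter into slot 2.
    A_layer[buffer_id, :lora_rank, :] = torch.randn(lora_rank, input_dim)
    B_layer[0, buffer_id, :, :lora_rank] = torch.randn(output_dim, lora_rank)

    # Reference LoRA delta for a batch of tokens routed to this slot:
    # (tokens, input_dim) @ A^T -> (tokens, rank), then @ B^T -> (tokens, output_dim).
    x = torch.randn(3, input_dim)
    delta = (x @ A_layer[buffer_id, :lora_rank, :].T) @ B_layer[
        0, buffer_id, :, :lora_rank
    ].T
    print(delta.shape)  # torch.Size([3, 96])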