# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
)

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine
from sglang.srt.utils import load_image
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError(f"Unsupported dtype: {torch_dtype}")


def get_top_logprobs(logits, k):
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    # Keep only the top-k logprob values; the indices are not needed here.
    logprobs, _ = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


def get_token_ids_logprobs(logits, token_ids):
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs = logprobs[..., token_ids]
    return logprobs


def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if is_sentence_transformer_model(model_path):
        model = SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
        )
    else:  # if no pre-trained sentence-transformers model
        from sentence_transformers import models

        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    return model.cuda()


@dataclass
class ModelOutput:
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    top_output_logprob_idx: List[List[int]] = None
    embed_logits: List[torch.Tensor] = None
    scores: List[float] = None
    input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    token_ids_input_logprobs: List[torch.Tensor] = None
    token_ids_output_logprobs: List[torch.Tensor] = None
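
# Note on the *_token_logprobs_lst fields above: SGLang reports per-token
# logprobs as (logprob, token_id, text) tuples, and these runners never
# request decoded text, so the third slot is always None — hence the
# Tuple[float, int, None] annotation.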

class HFRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
        trust_remote_code: bool = False,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only
        self.trust_remote_code = trust_remote_code

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        if model_path in models_needs_trust_remote:
            return True
        return False

    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):
        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[
                torch.arange(batch_size, device=outputs.last_hidden_state.device),
                sequence_lengths,
            ]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()
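
    # Queue protocol with the model subprocess: in_queue carries
    # (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
    # tuples and out_queue returns one ModelOutput per request, so the heavy
    # HF model lives entirely outside the test process.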
    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        # Load the model and tokenizer
        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.trust_remote_code,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            elif "clip" in model_path.lower():
                self.model = AutoModel.from_pretrained(model_path).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype
                )
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")

        self.tokenizer = get_tokenizer(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=self.trust_remote_code,
        )

        # Run forward
        while True:
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
                if self.model_type == "generation":
                    out_queue.put(
                        self.forward_generation_raw(
                            base_model=self.base_model,
                            prompts=prompts,
                            max_new_tokens=max_new_tokens,
                            tokenizer=self.tokenizer,
                            lora_paths=lora_paths,
                            torch_dtype=torch_dtype,
                            output_str_only=self.output_str_only,
                            token_ids_logprob=token_ids_logprob,
                        )
                    )
                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    if "gme-qwen2-vl" in model_path.lower():
                        logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
                    elif "clip" in model_path.lower():
                        if image_data is not None:
                            image = load_image(image_data)
                            inputs = self.processor(
                                images=image[0], return_tensors="pt"
                            )
                            logits = self.model.get_image_features(
                                pixel_values=inputs.data["pixel_values"].cuda(),
                            ).tolist()
                        else:
                            inputs = self.tokenizer(
                                prompts, padding=True, return_tensors="pt"
                            )
                            logits = self.model.get_text_features(
                                input_ids=inputs.data["input_ids"].cuda(),
                                attention_mask=inputs.data["attention_mask"].cuda(),
                            ).tolist()
                    else:
                        logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))
                elif self.model_type == "reward":
                    scores = []
                    for conv in prompts:
                        conv_formatted = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        conv_tokenized = self.tokenizer(
                            conv_formatted, return_tensors="pt"
                        ).to("cuda")
                        scores.append(
                            float(self.model(**conv_tokenized).logits[0][0].item())
                        )
                    out_queue.put(ModelOutput(scores=scores))
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[int] = None,
    ):
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None
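
    # Reference HF decoding path: generate() yields the output text plus
    # per-step scores for decode logprobs, and a separate forward() pass over
    # the prompt recovers the input ("prefill") logprobs for comparison.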
    @staticmethod
    def forward_generation_raw(
        base_model,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int,
        tokenizer,
        torch_dtype: torch.dtype,
        lora_paths: Optional[List[str]] = None,
        output_str_only: bool = False,
        token_ids_logprob: Optional[int] = None,
    ) -> ModelOutput:
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                model = PeftModel.from_pretrained(
                    base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                model = base_model

            outputs = model.generate(
                input_ids,
                do_sample=False,
                temperature=None,
                top_p=None,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_scores=(not output_str_only),
            )

            text = tokenizer.decode(
                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
            )
            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)

            if not output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                if token_ids_logprob is not None:
                    token_ids_output_logprobs.append(
                        [
                            get_token_ids_logprobs(
                                logits[0], token_ids_logprob
                            ).tolist()
                            for logits in outputs.scores
                        ]
                    )
                del outputs

                input_logits = model.forward(input_ids).logits[0]
                top_input_logprobs.append(
                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                )
                if token_ids_logprob is not None:
                    token_ids_input_logprobs.append(
                        get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
                    )
                del input_logits

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )


class SRTRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        chunked_prefill_size: Optional[int] = None,
        dp_size: int = 1,
        tokenizer_path: Optional[str] = None,
        enable_ep_moe: bool = False,
        mem_fraction_static: float = 0.65,
        trust_remote_code: bool = False,
        speculative_draft_model_path: Optional[str] = None,
        speculative_algorithm: Optional[str] = None,
        speculative_num_steps: Optional[int] = None,
        speculative_eagle_topk: Optional[int] = None,
        speculative_num_draft_tokens: Optional[int] = None,
        disable_overlap_schedule: bool = False,
        disable_custom_all_reduce: bool = False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        enable_dp_attention = dp_size > 1

        spec_kwargs = {}
        if speculative_draft_model_path:
            spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
            spec_kwargs["speculative_algorithm"] = speculative_algorithm
            spec_kwargs["speculative_num_steps"] = speculative_num_steps
            spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
            spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens

        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=mem_fraction_static,
            trust_remote_code=trust_remote_code,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
            chunked_prefill_size=chunked_prefill_size,
            enable_dp_attention=enable_dp_attention,
            dp_size=dp_size,
            tokenizer_path=tokenizer_path,
            enable_ep_moe=enable_ep_moe,
            disable_overlap_schedule=disable_overlap_schedule,
            cuda_graph_max_bs=4,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **spec_kwargs,
        )
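
        # Keep a local tokenizer only when no separate tokenizer_path was
        # given; with a custom tokenizer_path, tokenization is presumably
        # left entirely to the engine.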
        if tokenizer_path is None:
            self.tokenizer = get_tokenizer(
                model_path,
                trust_remote_code=trust_remote_code,
            )
        else:
            self.tokenizer = None

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        if self.is_generation:
            return self.forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
                logprob_start_len=logprob_start_len,
                top_k=top_k,
                token_ids_logprob=token_ids_logprob,
            )
        else:
            if self.model_type == "embedding":
                response = self.engine.encode(prompt=prompts, image_data=image_data)
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
                    logits = [response["embedding"]]
                return ModelOutput(embed_logits=logits)
            # reward model
            else:
                response = self.engine.encode(prompts)
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        Test serving by sending all prompts at once.
        Returns only output strings, no logprobs.
        """
        if self.is_generation:
            return self.batch_forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
            )
        else:
            response = self.engine.encode(prompts, image_data)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine

    @staticmethod
    def forward_generation_raw(
        engine: Engine,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        # The return value contains logprobs from prefill.
        output_strs = []
        output_ids = []
        # Input logprobs. Note that the last item in the input logprobs is
        # equivalent to the first item in the output logprobs.
        top_input_logprobs = []
        input_token_logprobs_lst = []
        top_output_logprobs = []
        output_token_logprobs_lst = []
        top_output_logprob_idx = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        if top_k:
            sampling_params["top_k"] = top_k
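
        # Layout of response["meta_info"] as used below: input_token_logprobs
        # covers the prompt tokens from logprob_start_len onward, where the
        # very first entry is a placeholder with no meaning; the output_*
        # lists cover the generated tokens.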
        for i, prompt in enumerate(prompts):
            response = engine.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=NUM_TOP_LOGPROBS,
                token_ids_logprob=token_ids_logprob,
            )
            text = response["text"]
            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)
            # output_ids.append(response["output_ids"])

            input_token_logprobs = response["meta_info"]["input_token_logprobs"]
            output_token_logprobs = response["meta_info"]["output_token_logprobs"]
            logprobs = response["meta_info"]["input_top_logprobs"]
            if token_ids_logprob is not None:
                input_token_ids_logprobs = response["meta_info"][
                    "input_token_ids_logprobs"
                ][1:]
            else:
                input_token_ids_logprobs = None

            num_prompt_tokens = response["meta_info"]["prompt_tokens"]
            assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
            assert len(logprobs) == num_prompt_tokens - logprob_start_len

            # The first token logprob has no meaning in sglang.
            input_token_logprobs = input_token_logprobs[1:]
            logprobs = logprobs[1:]
            assert len(input_token_logprobs) == len(logprobs)

            input_token_logprobs_lst.append(
                input_token_logprobs + [output_token_logprobs[0]]
            )
            output_token_logprobs_lst.append(output_token_logprobs)

            top_input_logprobs.append(
                [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )
            top_output_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            top_output_logprob_idx.append(
                [
                    [tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            if token_ids_logprob is not None:
                token_ids_input_logprobs.append(
                    [[tup[0] for tup in x] for x in input_token_ids_logprobs]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"][
                                "output_token_ids_logprobs"
                            ][0]
                        ]
                    ]
                )
                token_ids_output_logprobs.append(
                    [
                        [tup[0] for tup in x]
                        for x in response["meta_info"]["output_token_ids_logprobs"]
                    ]
                )

        return ModelOutput(
            output_strs=output_strs,
            output_ids=output_ids,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            input_token_logprobs_lst=input_token_logprobs_lst,
            output_token_logprobs_lst=output_token_logprobs_lst,
            top_output_logprob_idx=top_output_logprob_idx,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )

    @staticmethod
    def batch_forward_generation_raw(
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens,
        lora_paths,
        engine,
    ):
        # Only output strings are collected here; no logprobs are requested.
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = engine.generate(
            prompts,
            lora_path=lora_paths if lora_paths else None,
            sampling_params=sampling_params,
        )
        output_strs = [r["text"] for r in response]

        return ModelOutput(
            output_strs=output_strs,
        )

def monkey_patch_gemma2_sdpa():
    """
    Use sdpa by default to fix the OOM issue.
    Revert this commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        config._attn_implementation = "sdpa"
        return config

    setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)


def check_close_model_outputs(
    hf_outputs: ModelOutput,
    srt_outputs: ModelOutput,
    prefill_tolerance: float,
    decode_tolerance: float,
    rouge_l_tolerance: float,
    debug_text: str = "",
    check_logprobs: bool = True,
):
    # Compare output strings
    print(f"{hf_outputs.output_strs=}")
    print(f"{srt_outputs.output_strs=}")
    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
    print(f"{rouge_l_scores=}")
    assert all(
        score >= rouge_l_tolerance for score in rouge_l_scores
    ), f"Not all ROUGE-L scores are at least rouge_l_tolerance={rouge_l_tolerance}"

    if check_logprobs:
        for i in range(len(hf_outputs.output_strs)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with {debug_text} "
                    f"prefill_tolerance={prefill_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with {debug_text} "
                    f"decode_tolerance={decode_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )
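
# Illustrative end-to-end sketch: compare HF and SRT generation for one model.
# The model path below is a placeholder and the tolerances are loose example
# values; adjust both to the model actually under test.
if __name__ == "__main__":
    # Spawn is the CUDA-safe start method for HFRunner's model subprocess.
    mp.set_start_method("spawn", force=True)

    example_model = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical test model
    example_prompts = DEFAULT_PROMPTS[1:3]

    with HFRunner(example_model, torch.float16, model_type="generation") as hf_runner:
        hf_outputs = hf_runner.forward(example_prompts, max_new_tokens=8)

    with SRTRunner(example_model, torch.float16, model_type="generation") as srt_runner:
        srt_outputs = srt_runner.forward(example_prompts, max_new_tokens=8)

    check_close_model_outputs(
        hf_outputs,
        srt_outputs,
        prefill_tolerance=5e-2,
        decode_tolerance=5e-2,
        rouge_l_tolerance=0.85,
    )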