"""Test original log probability alignment between SGLang and Hugging Face. This test suite verifies the correctness of the `origin_logprobs` output (temperature=1) and the `logprobs` output (temperature=0.5) in SGLang by comparing it against raw logit-based probabilities computed directly from a reference Hugging Face model. The test covers the following scenarios: - Next-token prediction: Verifies that the log probability of the next token from SGLang matches the Hugging Face model. - Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are consistent with Hugging Face outputs. - Specified token IDs: Confirms that the original logprobs for specific token IDs match the values computed from Hugging Face logits. """ import os import random import unittest import numpy as np import torch import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer import sglang as sgl from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST # ------------------------- Configurable via env ------------------------- # MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST PROMPTS = [ "Hello, my name is", "The future of AI is", "The president of the United States is", "The capital of France is ", ] TOP_LOGPROBS_NUM = 50 NUM_RANDOM_TOKEN_IDS = 10 RTOL = 0.20 ATOL = 0.00 # ------------------------------------------------ torch.manual_seed(1234) if torch.cuda.is_available(): torch.cuda.manual_seed_all(1234) torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False class TestOriginalLogprob(unittest.TestCase): def setUp(self): # ----- HF side (float32 weights) ----- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") self.hf_model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float32, device_map="auto" ) # Shared sampling parameters self.sampling_params = { "temperature": 0.5, # SGLang uses 0.5, but original logprobs are used 1.0 "top_p": 1.0, "top_k": 10, "max_new_tokens": 1, } # --------------------------------------------------------------------- # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs) # against a reference HF log‑prob vector. 
    # ---------------------------------------------------------------------
    # Helper: compare one SGLang block (token_logprobs / top_logprobs /
    # ids_logprobs) against a reference HF log-prob vector.
    # ---------------------------------------------------------------------
    def assert_logprobs_block_equal(
        self,
        hf_log_probs: torch.Tensor,  # [V]
        token_log_probs: list,
        top_log_probs: list,
        ids_log_probs: list,
        random_token_ids: list,
        tag: str = "",
    ):
        # Token-level check: compare the logprob of each returned token.
        vals, idxs, _ = zip(*token_log_probs)
        sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32)
        sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long)
        hf_vals = hf_log_probs[sgl_idxs]
        self.assertTrue(
            torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] token-level mismatch at indices {sgl_idxs.tolist()}",
        )

        # Top-k check: compare the k largest HF logprobs with SGLang's top list.
        hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1)
        sgl_topk = torch.tensor(
            [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][
                :TOP_LOGPROBS_NUM
            ],
            dtype=torch.float32,
            device=self.hf_model.device,
        )
        k = min(hf_topk.numel(), sgl_topk.numel())
        self.assertTrue(
            torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] top-k mismatch",
        )

        # Specified-token-IDs check.
        indices = torch.tensor(
            random_token_ids, dtype=torch.long, device=hf_log_probs.device
        )
        hf_token_ids = hf_log_probs[indices]
        sgl_token_ids = torch.tensor(
            [v for v, _, _ in ids_log_probs[0]],
            device=self.hf_model.device,
            dtype=torch.float32,
        )
        self.assertTrue(
            torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] token-IDs mismatch",
        )

        # Optional: print max abs diff for quick diagnostics.
        max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item()
        print(f"[{tag}] max|diff| token-level = {max_diff:.4f}")

    def test_logprob_match(self):
        vocab_size = self.tokenizer.vocab_size
        for env_val in ["True", "False"]:
            with self.subTest(return_original_logprob=env_val):
                os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val

                # ----- SGLang side -----
                sgl_engine = sgl.Engine(
                    model_path=MODEL_ID,
                    skip_tokenizer_init=True,
                    trust_remote_code=True,
                    mem_fraction_static=0.60,
                )

                for prompt in PROMPTS:
                    random_token_ids = sorted(
                        random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS)
                    )

                    enc = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = enc["input_ids"].to(self.hf_model.device)
                    attn_mask = enc["attention_mask"].to(self.hf_model.device)

                    with torch.inference_mode():
                        hf_out = self.hf_model(
                            input_ids=input_ids,
                            attention_mask=attn_mask,
                            return_dict=True,
                        )
                    logits = hf_out.logits[:, -1, :]  # [1, V]
                    hf_log_probs = F.log_softmax(
                        logits.float() / self.sampling_params["temperature"], dim=-1
                    )[0]
                    hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0]

                    outputs = sgl_engine.generate(
                        input_ids=input_ids[0].tolist(),
                        sampling_params=self.sampling_params,
                        return_logprob=True,
                        top_logprobs_num=TOP_LOGPROBS_NUM,
                        token_ids_logprob=random_token_ids,
                    )
                    if isinstance(outputs, list):
                        outputs = outputs[0]
                    meta = outputs["meta_info"]

                    if env_val.lower() == "true":
                        # Original logprobs are enabled: compare against the
                        # temperature-1.0 HF log probabilities.
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_original_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})",
                        )
                    else:
                        # Otherwise compare against the temperature-scaled
                        # (0.5) HF log probabilities.
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"logprobs SGLang vs HF: {prompt} ({env_val})",
                        )

                sgl_engine.shutdown()


if __name__ == "__main__":
    unittest.main()
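
# Usage sketch (assumptions: this file is saved as test_original_logprob.py
# and runs in a CUDA-capable environment with sglang, torch, and transformers
# installed; the invocation below is illustrative, not prescribed by the test):
#
#   python test_original_logprob.py
#
# Each prompt is checked twice, once per RETURN_ORIGINAL_LOGPROB subTest, and
# the SGLang engine is started and shut down once per setting.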