sglang_v0.5.2/sglang/test/srt/test_original_logprobs.py

197 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Test original log probability alignment between SGLang and Hugging Face.
This test suite verifies the correctness of SGLang's original logprobs output
(temperature=1, returned when RETURN_ORIGINAL_LOGPROB is enabled) and its
regular logprobs output (temperature=0.5) by comparing them against
log-probabilities computed directly from the raw logits of a reference
Hugging Face model.

The test covers the following scenarios:
- Next-token prediction: verifies that the log probability of the next token
  from SGLang matches the Hugging Face model.
- Top-k logprobs: ensures that the top-k logprobs returned by SGLang are
  consistent with Hugging Face outputs.
- Specified token IDs: confirms that the logprobs for specific token IDs
  match the values computed from Hugging Face logits.
"""
import os
import random
import unittest
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# ----------------------------- Test settings ----------------------------- #
MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
PROMPTS = [
    "Hello, my name is",
    "The future of AI is",
    "The president of the United States is",
    "The capital of France is ",
]
TOP_LOGPROBS_NUM = 50  # top logprobs requested from SGLang and compared
NUM_RANDOM_TOKEN_IDS = 10  # random token ids whose logprobs are compared
RTOL = 0.20  # relative tolerance passed to torch.allclose
ATOL = 0.00  # absolute tolerance passed to torch.allclose
# ------------------------------------------------------------------------- #

# Seed every RNG the test draws from so runs are reproducible. Python's
# ``random`` chooses the token ids passed as ``token_ids_logprob``, so it
# must be seeded too, not just torch.
random.seed(1234)
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)
    # Disable TF32 so the HF reference matmuls run in full float32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
class TestOriginalLogprob(unittest.TestCase):
    """Compare SGLang logprob outputs against a float32 Hugging Face reference.

    The engine is started twice. With ``RETURN_ORIGINAL_LOGPROB=True``,
    SGLang's logprobs are compared with log_softmax of the raw HF logits
    (temperature 1). With ``False``, they are compared with log_softmax of
    the logits divided by the sampling temperature.
    """

    def setUp(self):
        """Load the Hugging Face reference model and shared sampling params."""
        # ----- HF side (float32 weights) -----
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float32, device_map="auto"
        )
        # Shared sampling parameters. SGLang samples with temperature 0.5.
        # The "original" logprobs checked when RETURN_ORIGINAL_LOGPROB is
        # enabled correspond to temperature 1.0 (raw logits).
        self.sampling_params = {
            "temperature": 0.5,
            "top_p": 1.0,
            "top_k": 10,
            "max_new_tokens": 1,
        }

    # ---------------------------------------------------------------------
    # Helper: compare one SGLang block (token_logprobs / top_logprobs /
    # ids_logprobs) against a reference HF logprob vector.
    # ---------------------------------------------------------------------
    def assert_logprobs_block_equal(
        self,
        hf_log_probs: torch.Tensor,  # [V]
        token_log_probs: list,
        top_log_probs: list,
        ids_log_probs: list,
        random_token_ids: list,
        tag: str = "",
    ):
        """Assert that one SGLang logprob block matches a HF reference vector.

        Args:
            hf_log_probs: Reference log-probabilities for the generated
                position, indexed by token id.
            token_log_probs: SGLang ``output_token_logprobs``. Each entry is
                a 3-tuple whose first two items are (logprob, token_id).
            top_log_probs: SGLang ``output_top_logprobs``. Only the first
                generated position (``[0]``) is checked, by value.
            ids_log_probs: SGLang ``output_token_ids_logprobs``. The first
                position's entries are compared in the order of
                ``random_token_ids``.
            random_token_ids: Token ids requested via ``token_ids_logprob``.
            tag: Label prefixed to assertion messages.
        """
        # Build every comparison tensor on the reference vector's device so
        # indexing and allclose never mix devices. With device_map="auto" the
        # logits may live somewhere other than ``hf_model.device``.
        device = hf_log_probs.device

        # --- Token level: logprob of each sampled token ---
        vals, idxs, _ = zip(*token_log_probs)
        sgl_vals = torch.tensor(vals, device=device, dtype=torch.float32)
        sgl_idxs = torch.tensor(idxs, device=device, dtype=torch.long)
        hf_vals = hf_log_probs[sgl_idxs]
        self.assertTrue(
            torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] tokenlevel mismatch at indices {sgl_idxs.tolist()}",
        )

        # --- Top-k: compare the sorted values only, not the token ids ---
        hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1)
        sgl_topk = torch.tensor(
            [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][
                :TOP_LOGPROBS_NUM
            ],
            dtype=torch.float32,
            device=device,
        )
        k = min(hf_topk.numel(), sgl_topk.numel())
        # Without this guard, an empty SGLang response would make allclose
        # compare two empty tensors and pass vacuously.
        self.assertGreater(k, 0, msg=f"[{tag}] topk: SGLang returned no top logprobs")
        self.assertTrue(
            torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] topk mismatch",
        )

        # --- Explicitly requested token ids ---
        indices = torch.tensor(random_token_ids, dtype=torch.long, device=device)
        hf_token_ids = hf_log_probs[indices]
        sgl_token_ids = torch.tensor(
            [v for v, _, _ in ids_log_probs[0]],
            device=device,
            dtype=torch.float32,
        )
        # A length mismatch would otherwise surface as a broadcasting error
        # inside allclose instead of a readable test failure.
        self.assertEqual(
            sgl_token_ids.numel(),
            indices.numel(),
            msg=f"[{tag}] tokenIDs count mismatch",
        )
        self.assertTrue(
            torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] tokenIDs mismatch",
        )

        # Optional: print max abs diff for quick diagnostics
        max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item()
        print(f"[{tag}] max|diff| tokenlevel = {max_diff:.4f}")

    def _hf_reference_log_probs(self, prompt: str):
        """Run the HF model on ``prompt`` and return reference logprobs.

        Returns:
            A tuple ``(input_ids, scaled, original)``:
            - ``input_ids``: the prompt's token ids as a Python list.
            - ``scaled``: log_softmax(logits / temperature) at the last
              prompt position.
            - ``original``: log_softmax(logits) at the last prompt position.
        """
        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"].to(self.hf_model.device)
        attn_mask = enc["attention_mask"].to(self.hf_model.device)
        with torch.inference_mode():
            hf_out = self.hf_model(
                input_ids=input_ids,
                attention_mask=attn_mask,
                return_dict=True,
            )
        logits = hf_out.logits[:, -1, :].float()  # [1, V]
        temperature = self.sampling_params["temperature"]
        scaled = F.log_softmax(logits / temperature, dim=-1)[0]
        original = F.log_softmax(logits, dim=-1)[0]
        return input_ids[0].tolist(), scaled, original

    def _check_prompt(self, sgl_engine, prompt: str, env_val: str, vocab_size: int):
        """Generate one token for ``prompt`` with SGLang and check its logprobs."""
        random_token_ids = sorted(random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS))
        input_ids, hf_log_probs, hf_original_log_probs = self._hf_reference_log_probs(
            prompt
        )
        outputs = sgl_engine.generate(
            input_ids=input_ids,
            sampling_params=self.sampling_params,
            return_logprob=True,
            top_logprobs_num=TOP_LOGPROBS_NUM,
            token_ids_logprob=random_token_ids,
        )
        if isinstance(outputs, list):
            outputs = outputs[0]
        meta = outputs["meta_info"]

        # With the flag on, SGLang's logprobs are expected to match the
        # temperature-1 reference. With it off, they should match the
        # temperature-scaled reference. Exactly one of the two is checked.
        if env_val.lower() == "true":
            reference, label = hf_original_log_probs, "Original logprobs"
        else:
            reference, label = hf_log_probs, "logprobs"
        self.assert_logprobs_block_equal(
            hf_log_probs=reference,
            token_log_probs=meta["output_token_logprobs"],
            top_log_probs=meta["output_top_logprobs"],
            ids_log_probs=meta["output_token_ids_logprobs"],
            random_token_ids=random_token_ids,
            tag=f"{label} SGLang vs HF: {prompt} ({env_val})",
        )

    def test_logprob_match(self):
        """Check SGLang logprobs with RETURN_ORIGINAL_LOGPROB on and off."""
        from unittest import mock

        vocab_size = self.tokenizer.vocab_size
        for env_val in ["True", "False"]:
            # patch.dict restores the previous environment on exit, so the
            # flag does not leak into other tests in this process.
            with self.subTest(return_original_logprob=env_val), mock.patch.dict(
                os.environ, {"RETURN_ORIGINAL_LOGPROB": env_val}
            ):
                # ----- SGLang side -----
                sgl_engine = sgl.Engine(
                    model_path=MODEL_ID,
                    skip_tokenizer_init=True,
                    trust_remote_code=True,
                    mem_fraction_static=0.60,
                )
                # Always shut the engine down. A failed assertion inside the
                # subTest would otherwise leave it running (presumably still
                # holding its mem_fraction_static share of GPU memory) while
                # the next iteration starts another engine.
                try:
                    for prompt in PROMPTS:
                        self._check_prompt(sgl_engine, prompt, env_val, vocab_size)
                finally:
                    sgl_engine.shutdown()
# Entry point for running this file directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")