sglang.0.4.8.post1/sglang/test/srt/models/test_cross_encoder_models.py

92 lines
2.6 KiB
Python

import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci
MODELS = [
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]
ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]
class TestCrossEncoderModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_prefill_logits(
self,
prompts,
model_path,
tp_size,
torch_dtype,
score_tolerance,
attention_backend,
) -> None:
with HFRunner(
model_path,
torch_dtype=torch_dtype,
model_type="cross_encoder",
) as hf_runner:
hf_scores = hf_runner.forward(prompts).scores
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="cross_encoder",
attention_backend=attention_backend,
chunked_prefill_size=-1,
disable_radix_cache=True,
) as srt_runner:
srt_scores = srt_runner.forward(prompts).scores
for i in range(len(srt_scores)):
score_difference = abs(hf_scores[i] - srt_scores[i])
assert (
score_difference < score_tolerance
), "cross encoder scores are not all close"
def preprocess_prompts(self, prompt):
processed_prompts = []
query = prompt["query"]
documents = prompt["documents"]
for document in documents:
processed_prompts.append([query, document])
return processed_prompts
def test_prefill_logits(self):
models_to_test = MODELS
if is_in_ci():
models_to_test = [random.choice(MODELS)]
for model, tp_size, prefill_tolerance in models_to_test:
for attention_backend in ATTENTION_BACKEND:
for queryDocs in TEST_RERANK_QUERY_DOCS:
prompts = self.preprocess_prompts(queryDocs)
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
prompts,
model,
tp_size,
torch_dtype,
prefill_tolerance,
attention_backend,
)
if __name__ == "__main__":
unittest.main()