152 lines
5.7 KiB
Python
152 lines
5.7 KiB
Python
import json
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import requests
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestInputEmbeds(CustomTestCase):
    """Exercise a launched server with text, ``input_embeds`` and file inputs.

    ``setUpClass`` starts one server process for the whole class and loads a
    reference tokenizer/model that is used locally to turn prompts into input
    embeddings. Each test sends requests and prints the responses; a non-200
    reply is turned into an ``{"error": ...}`` dict rather than raised, so the
    tests currently act as smoke checks and do not assert on generated text.
    """

    # Timeout, in seconds, passed to every ``requests.post`` call below.
    REQUEST_TIMEOUT = 30

    @classmethod
    def setUpClass(cls):
        """Load the reference tokenizer/model and launch the server once."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Every CLI argument is a string: command lines handed to
            # process-spawning APIs must not contain ints.
            # NOTE(review): the launcher is not visible here and may already
            # stringify its arguments; "4" is safe either way.
            other_args=["--disable-radix", "--cuda-graph-max-bs", "4"],
        )
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_payload(self, input_key, input_value):
        """Return a ``/generate`` payload carrying a single input field.

        Args:
            input_key: Name of the input field (``"text"`` or
                ``"input_embeds"``).
            input_value: Value to send under ``input_key``.

        Returns:
            A new dict with the model name, the input field and the sampling
            parameters shared by all tests (temperature 0, 50 new tokens).
        """
        return {
            "model": self.model,
            input_key: input_value,
            "sampling_params": {"temperature": 0, "max_new_tokens": 50},
        }

    @staticmethod
    def _parse_response(response):
        """Return the decoded JSON body, or an error dict for non-200 replies."""
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def generate_input_embeddings(self, text):
        """Generate input embeddings for a given text.

        Args:
            text: Prompt to tokenize and embed with the reference model.

        Returns:
            The embeddings as nested Python lists (JSON-serializable).
        """
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        # Drop only the leading batch axis. A bare ``squeeze()`` would also
        # remove the sequence axis for a single-token prompt and flatten the
        # result to one dimension.
        return embeddings.squeeze(0).tolist()

    def send_request(self, payload):
        """Send a POST request to the /generate endpoint and return the response.

        Args:
            payload: JSON-serializable request body.

        Returns:
            The decoded JSON response on HTTP 200, otherwise a dict with a
            single ``"error"`` key describing the failure.
        """
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        return self._parse_response(response)

    def send_file_request(self, file_path):
        """Send a POST request to the /generate_from_file endpoint with a file.

        Args:
            file_path: Path of the file to upload under the ``"file"`` field.

        Returns:
            The decoded JSON response on HTTP 200, otherwise a dict with a
            single ``"error"`` key describing the failure.
        """
        with open(file_path, "rb") as f:
            response = requests.post(
                self.base_url + "/generate_from_file",
                files={"file": f},
                timeout=self.REQUEST_TIMEOUT,
            )
        return self._parse_response(response)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_text_based_response(self):
        """Test and print API responses using text-based input."""
        for text in self.texts:
            response = self.send_request(self._build_payload("text", text))
            print(
                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_embedding_based_response(self):
        """Test and print API responses using input embeddings."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            response = self.send_request(
                self._build_payload("input_embeds", embeddings)
            )
            print(
                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_compare_text_vs_embedding(self):
        """Test and compare responses for text-based and embedding-based inputs."""
        for text in self.texts:
            text_payload = self._build_payload("text", text)
            embed_payload = self._build_payload(
                "input_embeds", self.generate_input_embeddings(text)
            )

            text_response = self.send_request(text_payload)
            embed_response = self.send_request(embed_payload)

            print(
                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
            )
            print(
                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
            )
            # This is flaky, so we skip this temporarily
            # self.assertEqual(text_response["text"], embed_response["text"])

    def test_generate_from_file(self):
        """Test the /generate_from_file endpoint using tokenized embeddings."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            # delete=False keeps the file on disk after the ``with`` block so
            # it can be reopened for upload; it is removed in ``finally``.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False
            ) as tmp_file:
                json.dump(embeddings, tmp_file)
                tmp_file_path = tmp_file.name

            try:
                response = self.send_file_request(tmp_file_path)
                print(
                    f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
            finally:
                # Ensure the temporary file is deleted
                os.remove(tmp_file_path)
|
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|