98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
import unittest
|
|
|
|
import openai
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestOpenAIEmbedding(CustomTestCase):
    """End-to-end tests for the embeddings endpoint via the ``openai`` client.

    One server process is launched for the whole class with the
    ``--is-embedding`` and ``--enable-metrics`` flags; every test then sends
    requests to it through ``openai.Client``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the embedding server shared by all tests in this class."""
        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Configure embedding-specific args
        other_args = ["--is-embedding", "--enable-metrics"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # The server is launched with the bare URL; the clients built by the
        # tests use the "/v1" prefix, so append it only after the launch.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def _client(self):
        """Return a new ``openai.Client`` configured for the test server."""
        return openai.Client(api_key=self.api_key, base_url=self.base_url)

    def _assert_embeddings(self, response, expected_count):
        """Assert ``response`` holds ``expected_count`` non-empty embeddings."""
        self.assertEqual(len(response.data), expected_count)
        for index, item in enumerate(response.data):
            # assertGreater reports the actual length on failure, unlike
            # assertTrue(len(...) > 0) which only says "False is not true".
            self.assertGreater(
                len(item.embedding), 0, msg=f"embedding {index} is empty"
            )

    def test_embedding_single(self):
        """Test single embedding request"""
        response = self._client().embeddings.create(
            model=self.model, input="Hello world"
        )
        self._assert_embeddings(response, 1)

    def test_embedding_batch(self):
        """Test batch embedding request"""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world", "Test text"]
        )
        self._assert_embeddings(response, 2)

    def test_embedding_single_batch_str(self):
        """Test embedding with a List[str] and length equals to 1"""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world"]
        )
        self._assert_embeddings(response, 1)

    def test_embedding_single_int_list(self):
        """Test embedding with a List[int] or List[List[int]]"""
        token_ids = [15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]

        # Nested form first, then the flat form; both must yield exactly one
        # embedding. A plain loop keeps the original fail-fast behaviour: the
        # first failing assertion aborts the test.
        for token_input in ([token_ids], token_ids):
            response = self._client().embeddings.create(
                model=self.model,
                input=token_input,
            )
            self._assert_embeddings(response, 1)

    def test_empty_string_embedding(self):
        """Test embedding an empty string."""
        client = self._client()

        # Text embedding example with empty string
        text = ""
        # Expect a BadRequestError for empty input
        with self.assertRaises(openai.BadRequestError) as cm:
            client.embeddings.create(
                model=self.model,
                input=text,
            )
        # check the status code
        self.assertEqual(cm.exception.status_code, 400)
|
|
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|