73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import unittest
|
|
|
|
import openai
|
|
|
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class TestFimCompletion(CustomTestCase):
    """End-to-end test of fill-in-the-middle (FIM) completions.

    Launches one server for a DeepSeek-Coder base model with the
    ``deepseek_coder`` completion template, then sends OpenAI-style
    completion requests that carry both a ``prompt`` and a ``suffix``.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, then launch the server shared by every test."""
        cls.model = "deepseek-ai/deepseek-coder-1.3b-base"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Load the tokenizer *before* launching the server: if setUpClass
        # raises after the launch, unittest never calls tearDownClass and the
        # server process would be left running.
        cls.tokenizer = get_tokenizer(cls.model)
        other_args = ["--completion-template", "deepseek_coder"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # The OpenAI client is pointed at the "/v1" API root.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in setUpClass."""
        kill_process_tree(cls.process.pid)

    def run_fim_completion(self, number_of_completion: int) -> None:
        """Request FIM completions and validate the response envelope.

        Args:
            number_of_completion: Passed as ``n`` to the completions API; the
                response must contain exactly this many choices.

        Raises:
            AssertionError: If any response check fails.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "function sum(a: number, b: number): number{\n"
        suffix = "}"

        prompt_input = self.tokenizer.encode(prompt) + self.tokenizer.encode(suffix)
        # NOTE(review): the +2 presumably covers the extra special tokens the
        # deepseek_coder FIM template adds around prompt/suffix beyond what the
        # two separate encode() calls already include -- confirm against the
        # template definition.
        num_prompt_tokens = len(prompt_input) + 2

        response = client.completions.create(
            model=self.model,
            prompt=prompt,
            suffix=suffix,
            temperature=0.3,
            max_tokens=32,
            stream=False,
            n=number_of_completion,
        )
        print(response)

        # unittest assertion methods (unlike bare `assert`) are not stripped
        # under `python -O` and report both operands on failure.
        self.assertEqual(len(response.choices), number_of_completion)
        self.assertTrue(response.id)
        self.assertTrue(response.created)
        self.assertEqual(response.object, "text_completion")
        self.assertEqual(
            response.usage.prompt_tokens,
            num_prompt_tokens,
            f"{response.usage.prompt_tokens} vs {num_prompt_tokens}",
        )
        self.assertGreater(response.usage.completion_tokens, 0)
        self.assertGreater(response.usage.total_tokens, 0)

    def test_fim_completion(self):
        """Check both a single completion and multiple (n > 1) completions."""
        for number_of_completion in [1, 3]:
            # subTest reports which `n` failed and still runs the others.
            with self.subTest(number_of_completion=number_of_completion):
                self.run_fim_completion(number_of_completion)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand off to unittest's command-line
    # runner, which discovers and runs the TestCase classes in this module.
    unittest.main()
|