sglang_v0.5.2/sglang/test/srt/test_fim_completion.py

import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestFimCompletion(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "deepseek-ai/deepseek-coder-1.3b-base"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Launch the server with the DeepSeek Coder completion template so that
        # prompt/suffix requests are rendered as fill-in-the-middle (FIM) prompts.
        other_args = ["--completion-template", "deepseek_coder"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_fim_completion(self, number_of_completion):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "function sum(a: number, b: number): number{\n"
        suffix = "}"
        # Expected prompt token count: both encoded segments plus two extra tokens,
        # presumably accounting for the special tokens added by the FIM template.
        prompt_input = self.tokenizer.encode(prompt) + self.tokenizer.encode(suffix)
        num_prompt_tokens = len(prompt_input) + 2

        response = client.completions.create(
            model=self.model,
            prompt=prompt,
            suffix=suffix,
            temperature=0.3,
            max_tokens=32,
            stream=False,
            n=number_of_completion,
        )
        print(response)
        print(len(response.choices))

        assert len(response.choices) == number_of_completion
        assert response.id
        assert response.created
        assert response.object == "text_completion"
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def test_fim_completion(self):
        for number_of_completion in [1, 3]:
            self.run_fim_completion(number_of_completion)


if __name__ == "__main__":
    unittest.main()