This commit is contained in:
hailin 2025-08-01 09:28:41 +08:00
parent f86051512d
commit 26f8dc9ab5
1 changed files with 27 additions and 3 deletions

View File

@ -53,6 +53,12 @@ def backend(text, sampling, api_suffix):
"prompt": text,
**sampling
}
elif api_suffix == "/v1/chat/completions":
payload = {
"model": MODEL_NAME,
"messages": text, # ← 这里 text 实际是 messages list
**sampling
}
log(f"\n🟡 [{now()}] POST {url}\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
try:
@ -75,6 +81,10 @@ def backend(text, sampling, api_suffix):
txt = choice.get("text", "").strip()
fr = choice.get("finish_reason")
ctok = data.get("usage", {}).get("completion_tokens")
elif api_suffix == "/v1/chat/completions":
choice = data.get("choices", [{}])[0]
msg = choice.get("message", {})
txt = msg.get("content", "").strip()
log(f"🟢 [{now()}] HTTP {r.status_code} tokens={ctok} finish={fr}\n"
f"🟢 resp800={r.text[:800]!r}")
@ -95,6 +105,19 @@ def chat(
):
from queue import Queue, Empty
# 构造 OpenAI 风格 messages仅用于 /v1/chat/completions
if api_suffix == "/v1/chat/completions":
messages = []
for u, a in history:
messages.append({"role": "user", "content": u})
messages.append({"role": "assistant", "content": a})
user_input = user_msg["text"] if isinstance(user_msg, dict) and "text" in user_msg else user_msg
messages.append({"role": "user", "content": user_input})
prompt_input = messages
else:
prompt_input = user # 原来的单轮文本 prompt
# 解析传入的 ChatInput 格式
user = user_msg["text"] if isinstance(user_msg, dict) and "text" in user_msg else user_msg
@ -112,7 +135,8 @@ def chat(
result_q = Queue()
def worker():
out = backend(user, samp, api_suffix)
#out = backend(user, samp, api_suffix)
out = backend(prompt_input, samp, api_suffix)
result_q.put(out)
# threading.Thread(target=worker).start()
@ -161,7 +185,7 @@ with gr.Blocks(title="调试界面") as demo:
gr.Markdown(f"## 💬 调试界面 \n权重 **{MODEL_PATH.name}**")
with gr.Row():
api_choice = gr.Dropdown(choices=["/generate", "/v1/completions"],
api_choice = gr.Dropdown(choices=["/generate", "/v1/completions", "/v1/chat/completions"],
value="/generate", label="选择推理接口")
with gr.Row():
max_new = gr.Slider(32, 32768, 128, label="max_new_tokens")