# sglang.0.4.8.post1/meta_ui.py

import json, datetime, textwrap, requests, gradio as gr
from pathlib import Path
from collections import deque
import queue, threading, time
# ────────────────── Basic configuration ──────────────────
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")
def model_name(path: Path):
    cfg = path / "config.json"
    if cfg.exists():
        data = json.load(cfg.open())
        return data.get("architectures", [None])[0] or data.get("model_type") or path.name
    return path.name
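# Illustrative only (values are assumptions, not read from a real config.json):
# a checkpoint directory with a config.json would typically resolve via its
# "architectures" entry, e.g. model_name(Path("/models/Qwen3-30B-A3B-Base"))
# -> something like "Qwen3MoeForCausalLM"; a directory without config.json
# falls back to the directory name itself.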
MODEL_NAME = model_name(MODEL_PATH)
now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
# ────────────────── Log queue ──────────────────
LOG_Q: "queue.Queue[str]" = queue.Queue()
LOG_TXT = ""
def log(msg):
    print(msg, flush=True)
    LOG_Q.put(msg)
prev_log_value = ""
def consume_logs(dummy=None):
    global LOG_TXT, prev_log_value
    buf = deque(LOG_TXT.splitlines(), maxlen=400)
    while not LOG_Q.empty():
        buf.append(LOG_Q.get())
    LOG_TXT = "\n".join(buf)
    if LOG_TXT != prev_log_value:
        prev_log_value = LOG_TXT
        return gr.update(value=LOG_TXT)
    return gr.update()
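# Design note: log() may be called from worker threads, so messages are funneled
# through the thread-safe LOG_Q and drained here on the Gradio Timer's 1 s tick;
# the deque caps the on-screen log at the last 400 lines, and a no-op gr.update()
# is returned when nothing changed so the Textbox is left untouched.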
# ────────────────── Backend call ──────────────────
def backend(text, sampling, api_suffix):
    url = f"http://localhost:30000{api_suffix}"
    if api_suffix == "/generate":
        payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
    elif api_suffix == "/v1/completions":
        payload = {
            "model": MODEL_NAME,
            "prompt": text,
            **sampling
        }
    elif api_suffix == "/v1/chat/completions":
        payload = {
            "model": MODEL_NAME,
            "messages": text,  # ← here `text` is actually a list of messages
            **sampling
        }
    log(f"\n🟡 [{now()}] POST {url}\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        r = requests.post(url,
                          headers={"Authorization": f"Bearer {API_KEY}",
                                   "Content-Type": "application/json"},
                          json=payload, timeout=180)
        try:
            data = r.json()
        except Exception:
            data = {}
        if api_suffix == "/generate":
            txt = data.get("text", "").strip()
            meta = data.get("meta_info", {})
            fr = meta.get("finish_reason")
            ctok = meta.get("completion_tokens")
        elif api_suffix == "/v1/completions":
            choice = data.get("choices", [{}])[0]
            txt = choice.get("text", "").strip()
            fr = choice.get("finish_reason")
            ctok = data.get("usage", {}).get("completion_tokens")
        elif api_suffix == "/v1/chat/completions":
            choice = data.get("choices", [{}])[0]
            msg = choice.get("message", {})
            txt = (msg.get("content") or "").strip()
            # completion_tokens comes from the usage block for the chat endpoint
            ctok = data.get("usage", {}).get("completion_tokens")
            fr = choice.get("finish_reason")  # kept in case the finish reason is needed later
        log(f"🟢 [{now()}] HTTP {r.status_code} tokens={ctok} finish={fr}\n"
            f"🟢 resp800={r.text[:800]!r}")
        if r.status_code != 200:
            return f"[HTTP {r.status_code}] {r.text[:300]}"
        return txt or "[⚠ empty]"
    except Exception as e:
        log(f"[❌ request error] {e}")
        return f"[❌ request error] {e}"
# ────────────────── Chat callback ──────────────────
def chat(
    user_msg, history,
    max_new, temp, top_p, top_k,
    rep_pen, pres_pen, stop_raw,
    api_suffix, log_state
):
    from queue import Queue, Empty
    user = user_msg["text"] if isinstance(user_msg, dict) and "text" in user_msg else user_msg
    if api_suffix == "/v1/chat/completions":
        # pass the full history to the LLM so it has the conversation context
        messages = history[:]
        messages.append({"role": "user", "content": user})
        prompt_input = messages
    else:
        prompt_input = user
    stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
    samp = {
        ("max_tokens" if api_suffix == "/v1/completions" else "max_new_tokens"): int(max_new),
        "temperature": temp,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
        **({"stop": stop} if stop else {})
    }
    result_q = Queue()

    def worker():
        out = backend(prompt_input, samp, api_suffix)
        result_q.put(out)

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    if api_suffix == "/v1/chat/completions":
        while True:
            if not thread.is_alive() and result_q.empty():
                break
            try:
                result = result_q.get(timeout=0.1)
            except Empty:
                continue
            if isinstance(result, str):
                result = {"text": result}
            elif not isinstance(result, dict) or "text" not in result:
                result = {"text": str(result)}
            # do NOT append to history here (the UI should only show the current reply);
            # the full history has already been forwarded to the LLM above
            yield result["text"], None
            return
    else:
        result = None
        # keep draining the queue even after the worker exits, so a result that
        # arrives just as the thread finishes is not dropped
        while thread.is_alive() or not result_q.empty():
            try:
                result = result_q.get(timeout=0.1)
                break
            except Empty:
                continue
        if isinstance(result, str):
            result = {"text": result}
        elif not isinstance(result, dict) or "text" not in result:
            result = {"text": str(result)}
        yield result["text"], log_state
        return
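# For reference, with type="messages" Gradio hands `history` to the chat callback
# above as an OpenAI-style list (values here are made up for illustration):
#   [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
# which is why it can be copied and extended directly as the chat-completions payload.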
# ────────────────── Gradio UI ──────────────────
with gr.Blocks(title="Debug UI") as demo:
    gr.Markdown(f"## 💬 Debug UI\nWeights: **{MODEL_PATH.name}**")
    with gr.Row():
        api_choice = gr.Dropdown(choices=["/generate", "/v1/completions", "/v1/chat/completions"],
                                 value="/generate", label="Inference endpoint")
    with gr.Row():
        max_new = gr.Slider(32, 32768, 128, label="max_new_tokens")
        temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
    with gr.Row():
        top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0, 2, 0.0, step=0.05, label="presence_penalty")
    stop_txt = gr.Textbox("", label="stop sequences (comma-separated)")
    log_state = gr.State("")
    dbg_chk = gr.Checkbox(label="📜 Show debug panel", value=False)
    log_box = gr.Textbox(label="Live log", lines=20, interactive=False, visible=False)
    chat_ui = gr.ChatInterface(
        fn=chat,
        additional_inputs=[max_new, temp, top_p, top_k,
                           rep_pen, pres_pen, stop_txt,
                           api_choice, log_state],
        additional_outputs=[log_state],
        type="messages"
    )
    timer = gr.Timer(1.0, render=True)
    timer.tick(
        fn=consume_logs,
        inputs=[],
        outputs=[log_box],
    )

    def clear_all_logs(_):
        global LOG_TXT, prev_log_value
        with LOG_Q.mutex:
            LOG_Q.queue.clear()
        LOG_TXT = ""
        prev_log_value = ""
        return gr.update(value=""), gr.update(value="")

    api_choice.change(fn=clear_all_logs, inputs=api_choice, outputs=[log_state, log_box])
    log_state.change(lambda txt: gr.update(value=txt), log_state, log_box)
    dbg_chk.change(lambda v: gr.update(visible=v), dbg_chk, log_box)
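# Assumed companion command for the server this UI talks to on localhost:30000
# (not part of this file; flags are from memory of the SGLang CLI and may differ by version):
#   python -m sglang.launch_server --model-path /root/.cradle/Alibaba/Qwen3-30B-A3B-Base \
#       --port 30000 --api-key token-abc123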
demo.launch(server_name="0.0.0.0", server_port=30001)