# sglang.0.4.8.post1/meta_ui.py

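"""Minimal Gradio debugging UI for a locally running SGLang server.

Prompts are sent to SGLang's native /generate endpoint with user-tunable sampling
parameters; the raw request/response log is mirrored into an optional debug panel
that refreshes on a timer.
"""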

import json, datetime, textwrap, requests, gradio as gr
from pathlib import Path
from collections import deque
import queue, threading, time
# ───────────────────── Basic configuration ─────────────────────
API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")
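# The UI assumes an SGLang server is already serving the model above, e.g.
# (a sketch only; exact flags may differ between sglang versions):
#   python -m sglang.launch_server --model-path /root/.cradle/Alibaba/Qwen3-30B-A3B-Base \
#       --port 30000 --api-key token-abc123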
def model_name(path: Path):
    cfg = path / "config.json"
    if cfg.exists():
        data = json.load(cfg.open())
        return data.get("architectures", [None])[0] or data.get("model_type") or path.name
    return path.name
MODEL_NAME = model_name(MODEL_PATH)
now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
# ───────────────────── Log queue ─────────────────────
LOG_Q: "queue.Queue[str]" = queue.Queue()
LOG_TXT = ""  # ✅ global log cache so chat focus does not block log_box updates

def log(msg):  # print to the terminal and push onto the queue
    print(msg, flush=True)
    LOG_Q.put(msg)

prev_log_value = ""  # log content from the previous frame
def consume_logs(dummy=None):
    """Refresh log_box once per second so chat does not block UI updates."""
    global LOG_TXT, prev_log_value
    buf = deque(LOG_TXT.splitlines(), maxlen=400)
    while not LOG_Q.empty():
        buf.append(LOG_Q.get())
    LOG_TXT = "\n".join(buf)
    if LOG_TXT != prev_log_value:
        prev_log_value = LOG_TXT
        return gr.update(value=LOG_TXT)
    return gr.update()  # no change, so skip the frontend refresh
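# Note: consume_logs keeps at most 400 lines (deque maxlen), so the payload pushed
# to the textbox stays bounded during long sessions; older lines are dropped.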
# ───────────────────── Backend call ─────────────────────
def backend(text, sampling):
    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
    log(f"\n🟡 [{now()}] payload\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        r = requests.post(API_URL,
                          headers={"Authorization": f"Bearer {API_KEY}",
                                   "Content-Type": "application/json"},
                          json=payload, timeout=180)
        try:
            data = r.json()
        except Exception:
            data = {}
        fr = data.get("meta_info", {}).get("finish_reason")
        ctok = data.get("meta_info", {}).get("completion_tokens")
        log(f"🟢 [{now()}] HTTP {r.status_code} tokens={ctok} finish={fr}\n"
            f"🟢 resp800={r.text[:800]!r}")
        if r.status_code != 200:
            return f"[HTTP {r.status_code}] {r.text[:300]}"
        return data.get("text", "").strip() or "[⚠ empty]"
    except Exception as e:
        log(f"[❌ request failed] {e}")
        return f"[❌ request failed] {e}"
# ───────────────────── Chat callback ─────────────────────
def chat(
    user, history,
    max_new, temp, top_p, top_k,
    rep_pen, pres_pen, stop_raw,
    log_state
):
    import threading
    from queue import Queue, Empty

    stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
    samp = {
        "max_new_tokens": int(max_new),
        "temperature": temp,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
        **({"stop": stop} if stop else {})
    }
    result_q = Queue()

    # Run the backend request in a worker thread
    def worker():
        out = backend(user, samp)
        result_q.put(out)

    thread = threading.Thread(target=worker)
    thread.start()

    # Show a placeholder immediately
    yield "⏳ Generating...", log_state
    # Poll the result queue every 0.1 s (avoids blocking the UI)
    while thread.is_alive() or not result_q.empty():
        try:
            result = result_q.get(timeout=0.1)
            yield result, log_state
        except Empty:
            continue
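# Note: because the ChatInterface below is built with additional_outputs=[log_state],
# chat() yields (reply, log_state) tuples rather than bare strings; the worker thread
# plus 0.1 s polling keeps the generator (and thus the UI) responsive while the HTTP
# request is in flight.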
# ───────────────────── Gradio UI ─────────────────────
with gr.Blocks(title="Debug UI") as demo:
    gr.Markdown(f"## 💬 Debug UI \nWeights **{MODEL_PATH.name}**")

    # Sampling parameter controls
    with gr.Row():
        max_new = gr.Slider(32, 32768, 128, label="max_new_tokens")
        temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
    with gr.Row():
        top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0, 2, 0.0, step=0.05, label="presence_penalty")
    stop_txt = gr.Textbox("", label="stop sequences (comma separated)")

    log_state = gr.State("")  # state passed through the chat callback
    dbg_chk = gr.Checkbox(label="📜 Show debug panel", value=False)  # ✅ off by default
    log_box = gr.Textbox(label="Live log", lines=20, interactive=False, visible=False)  # ✅ hidden by default

    # Chat interface (placed before the log)
    # chatbot = gr.ChatInterface(
    #     fn=chat,
    #     additional_inputs=[max_new, temp, top_p, top_k,
    #                        rep_pen, pres_pen, stop_txt, log_state],
    #     additional_outputs=[log_state],
    #     type="messages"
    # )
    chatbot = gr.ChatInterface(
        fn=chat,
        textbox=gr.Textbox(lines=2, placeholder="Text...", label="Prompt"),
        submit_btn=gr.Button(value="", icon="send"),
        additional_inputs=[max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_txt, log_state],
        additional_outputs=[log_state],
        type="messages"
    )
    # Timer that refreshes the log panel
    timer = gr.Timer(1.0, render=True)
    timer.tick(
        fn=consume_logs,
        inputs=[],
        outputs=[log_box],
    )
    log_state.change(lambda txt: gr.update(value=txt), log_state, log_box)
    dbg_chk.change(lambda v: gr.update(visible=v), dbg_chk, log_box)
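# server_name="0.0.0.0" exposes the UI on every network interface of the host;
# switch to "127.0.0.1" if it should only be reachable locally.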
demo.launch(server_name="0.0.0.0", server_port=30001)