sglang.0.4.8.post1/meta_ui.py

import datetime
import json
import queue
import threading
from collections import deque
from pathlib import Path

import gradio as gr
import requests
# ────────────────────────────────────
# 1. Server endpoint & model weight path
# ────────────────────────────────────
API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")  # ← replace with your weights directory
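
# A matching SGLang server can be started roughly like this (command and flags are
# an assumption based on common SGLang usage; adjust to your installation):
#   python -m sglang.launch_server --model-path /root/.cradle/Alibaba/Qwen3-30B-A3B-Base \
#       --port 30000 --api-key token-abc123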
def detect_model_name(path: Path) -> str:
    """Infer a model name from config.json, falling back to the directory name."""
    cfg = path / "config.json"
    if cfg.exists():
        data = json.loads(cfg.read_text())
        return data.get("architectures", [None])[0] or data.get("model_type") or path.name
    return path.name
MODEL_NAME = detect_model_name(MODEL_PATH)
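# For a Qwen3-30B-A3B checkpoint this typically resolves to something like
# "Qwen3MoeForCausalLM" from the "architectures" field (illustrative; depends on config.json).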
now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
# ────────────────────────────────────
# 2. Log queue (live view in the UI)
# ────────────────────────────────────
LOG_Q: "queue.Queue[str]" = queue.Queue()

def log(msg: str):
    print(msg, flush=True)  # write to stdout
    LOG_Q.put(msg)          # push to the UI log queue

def log_worker(log_box: gr.Textbox):
    buf = deque(maxlen=400)  # keep the most recent 400 lines
    while True:
        line = LOG_Q.get()
        buf.append(line)
        # Note: assigning to .value only updates the component's server-side state;
        # depending on the Gradio version the browser may not refresh until the
        # component is re-rendered.
        log_box.value = "\n".join(buf)
# ────────────────────────────────────
# 3. Call /generate
# ────────────────────────────────────
def call_backend(text: str, sampling: dict):
    payload = {
        "model": MODEL_NAME,
        "text": text,
        "sampling_params": sampling
    }
    log(f"\n🟡 [{now()}] payload ↓\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        resp = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {API_KEY}",
                     "Content-Type": "application/json"},
            json=payload, timeout=180
        )
        status = resp.status_code
        body = resp.text
        try:
            data = resp.json()
        except Exception:
            data = {}
        finish = data.get("meta_info", {}).get("finish_reason")
        c_tok = data.get("meta_info", {}).get("completion_tokens")
        log(f"🟢 [{now()}] HTTP {status} "
            f"completion_tokens={c_tok} finish_reason={finish}\n"
            f"🟢 first 800 chars of resp ↓\n{body[:800]!r}")
        if status != 200:
            return f"[HTTP {status}] {body[:300]}"
        return data.get("text", "").strip() or "[⚠ backend returned empty text]"
    except Exception as e:
        log(f"[❌ request failed] {e}")
        return f"[❌ request failed] {e}"
# ────────────────────────────────────
# 4. ChatInterface callback
# ────────────────────────────────────
def chat(
    user_msg, history,
    max_new, temp, top_p, top_k,
    rep_pen, pres_pen, stop_raw
):
    stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
    sampling = {
        "max_new_tokens": int(max_new),
        "temperature": temp,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
        **({"stop": stop} if stop else {})
    }
    return call_backend(user_msg, sampling)
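# Example resulting sampling_params with the default slider values below (illustrative;
# the "stop" key only appears if stop sequences were entered, e.g. "###"):
#   {"max_new_tokens": 2048, "temperature": 0.8, "top_p": 0.95, "top_k": 50,
#    "repetition_penalty": 1.05, "presence_penalty": 0.0, "stop": ["###"]}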
# ────────────────────────────────────
# 5. Gradio UI
# ────────────────────────────────────
with gr.Blocks(title="Base Model Debug UI") as demo:
    gr.Markdown(f"## 💬 Base Model Debug UI  \nWeights: **{MODEL_PATH.name}**")
    # Sampling controls
    with gr.Row():
        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
        temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
    with gr.Row():
        top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0, 2.0, 0.0, step=0.05, label="presence_penalty")
    stop_text = gr.Textbox("", label="stop sequences (comma-separated)")
    dbg_toggle = gr.Checkbox(label="📜 Show / hide debug panel", value=False)
    dbg_box = gr.Textbox(label="Live log", lines=20, visible=False)
    # Chatbot
    chatbot = gr.ChatInterface(
        fn=chat,
        additional_inputs=[max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_text],
        type="messages"
    )
    # Toggle log panel visibility
    dbg_toggle.change(lambda v: gr.update(visible=v), dbg_toggle, dbg_box)
    # Start the background log thread
    threading.Thread(target=log_worker, args=(dbg_box,), daemon=True).start()

demo.launch(server_name="0.0.0.0", server_port=30001)
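# Once the SGLang server and this script are both running, open
# http://<server-ip>:30001 in a browser to use the debug UI.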