From 82e5957f8e4b376e0c51cc4276824bd5052524e5 Mon Sep 17 00:00:00 2001 From: hailin Date: Sun, 27 Jul 2025 12:42:37 +0800 Subject: [PATCH] . --- meta_ui.py | 145 +++++++++++++++++++++++++---------------------------- 1 file changed, 68 insertions(+), 77 deletions(-) diff --git a/meta_ui.py b/meta_ui.py index 3f887f7..db67538 100644 --- a/meta_ui.py +++ b/meta_ui.py @@ -1,103 +1,86 @@ -import os, json, datetime, textwrap, requests, gradio as gr +import json, datetime, textwrap, requests, gradio as gr from pathlib import Path from collections import deque -import threading, time, queue +import queue, threading, time -# ──────────────────────────────────── -# 1. 服务端 & 权重路径 -# ──────────────────────────────────── +# ───────────────────── 基础配置 ───────────────────── API_URL = "http://localhost:30000/generate" API_KEY = "token-abc123" -MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base") # ← 换成你的权重目录 +MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base") -def detect_model_name(path: Path) -> str: +def model_name(path: Path): cfg = path / "config.json" if cfg.exists(): data = json.load(cfg.open()) return data.get("architectures", [None])[0] or data.get("model_type") or path.name return path.name -MODEL_NAME = detect_model_name(MODEL_PATH) +MODEL_NAME = model_name(MODEL_PATH) now = lambda: datetime.datetime.now().strftime("%H:%M:%S") -# ──────────────────────────────────── -# 2. 日志队列(UI 实时查看) -# ──────────────────────────────────── +# ───────────────────── 日志队列 ───────────────────── LOG_Q: "queue.Queue[str]" = queue.Queue() +def log(msg): # 写终端 + 推队列 + print(msg, flush=True) + LOG_Q.put(msg) -def log(msg: str): - print(msg, flush=True) # 写到 stdout - LOG_Q.put(msg) # 送到 UI +def consume_logs(state_txt: str): + """供 Interval 调用:把队列里所有新行取出拼接到 state""" + buf = deque(state_txt.splitlines(), maxlen=400) + while not LOG_Q.empty(): + buf.append(LOG_Q.get()) + return "\n".join(buf) -def log_worker(log_box: gr.Textbox): - buf = deque(maxlen=400) # 最近 400 行 - while True: - line = LOG_Q.get() - buf.append(line) - log_box.value = "\n".join(buf) - -# ──────────────────────────────────── -# 3. 调用 /generate -# ──────────────────────────────────── -def call_backend(text: str, sampling: dict): - payload = { - "model": MODEL_NAME, - "text": text, - "sampling_params": sampling - } - log(f"\n🟡 [{now()}] payload ↓\n{json.dumps(payload, ensure_ascii=False, indent=2)}") +# ───────────────────── 后端调用 ───────────────────── +def backend(text, sampling): + payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling} + log(f"\n🟡 [{now()}] payload\n{json.dumps(payload, ensure_ascii=False, indent=2)}") try: - resp = requests.post( - API_URL, - headers={"Authorization": f"Bearer {API_KEY}", - "Content-Type": "application/json"}, - json=payload, timeout=180 - ) - status = resp.status_code - body = resp.text + r = requests.post(API_URL, + headers={"Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json"}, + json=payload, timeout=180) try: - data = resp.json() + data = r.json() except Exception: data = {} - finish = data.get("meta_info", {}).get("finish_reason") - c_tok = data.get("meta_info", {}).get("completion_tokens") - log(f"🟢 [{now()}] HTTP {status} " - f"completion_tokens={c_tok} finish_reason={finish}\n" - f"🟢 resp 前 800 字 ↓\n{body[:800]!r}") - if status != 200: - return f"[HTTP {status}] {body[:300]}" - return data.get("text", "").strip() or "[⚠ 后端返回空文本]" + fr = data.get("meta_info", {}).get("finish_reason") + ctok = data.get("meta_info", {}).get("completion_tokens") + log(f"🟢 [{now()}] HTTP {r.status_code} tokens={ctok} finish={fr}\n" + f"🟢 resp800={r.text[:800]!r}") + if r.status_code != 200: + return f"[HTTP {r.status_code}] {r.text[:300]}" + return data.get("text", "").strip() or "[⚠ 空]" except Exception as e: log(f"[❌ 请求异常] {e}") return f"[❌ 请求异常] {e}" -# ──────────────────────────────────── -# 4. ChatInterface 回调 -# ──────────────────────────────────── +# ───────────────────── Chat 回调 ───────────────────── def chat( - user_msg, history, + user, history, max_new, temp, top_p, top_k, - rep_pen, pres_pen, stop_raw + rep_pen, pres_pen, stop_raw, + log_state ): stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None - sampling = { + samp = { "max_new_tokens": int(max_new), - "temperature": temp, - "top_p": top_p, - "top_k": int(top_k), + "temperature": temp, + "top_p": top_p, + "top_k": int(top_k), "repetition_penalty": rep_pen, - "presence_penalty": pres_pen, + "presence_penalty": pres_pen, **({"stop": stop} if stop else {}) } - return call_backend(user_msg, sampling) + out = backend(user, samp) + # 返回回答,同时把 log_state 原样带回(不刷新由 Interval 处理) + return out, log_state -# ──────────────────────────────────── -# 5. Gradio UI -# ──────────────────────────────────── -with gr.Blocks(title="Base 模型调试界面") as demo: - gr.Markdown(f"## 💬 Base 模型调试界面 \n权重 **{MODEL_PATH.name}**") +# ───────────────────── Gradio UI ───────────────────── +with gr.Blocks(title="调试界面") as demo: + gr.Markdown(f"## 💬 调试界面 \n权重 **{MODEL_PATH.name}**") - # 采样控件 + # 采样参数控件 with gr.Row(): max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens") temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature") @@ -105,24 +88,32 @@ with gr.Blocks(title="Base 模型调试界面") as demo: top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p") top_k = gr.Slider(0, 200, 50, step=1, label="top_k") with gr.Row(): - rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty") - pres_pen= gr.Slider(0, 2.0, 0.0, step=0.05, label="presence_penalty") - stop_text = gr.Textbox("", label="stop 序列(逗号分隔)") + rep_pen = gr.Slider(0.8, 2, 1.05, step=0.01, label="repetition_penalty") + pres_pen= gr.Slider(0, 2, 0.0, step=0.05, label="presence_penalty") + stop_txt = gr.Textbox("", label="stop 序列(逗号分隔)") - dbg_toggle = gr.Checkbox(label="📜 显示 / 隐藏 Debug 面板", value=False) - dbg_box = gr.Textbox(label="实时日志", lines=20, visible=False) + dbg_chk = gr.Checkbox(label="📜 显示 Debug 面板", value=True) + log_box = gr.Textbox(label="实时日志", lines=20, interactive=False, visible=True) + log_state= gr.State("") # 保存全部日志文本 + + # 定时刷新日志 + logger = gr.Interval(1.0, visible=False) + logger.set_event_trigger( + fn=consume_logs, + inputs=log_state, + outputs=log_state + ) + # 显示到 log_box + log_state.change(lambda txt: gr.update(value=txt), log_state, log_box) + + # Debug 面板可见性切换 + dbg_chk.change(lambda v: gr.update(visible=v), dbg_chk, log_box) # Chatbot chatbot = gr.ChatInterface( fn=chat, - additional_inputs=[max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_text], + additional_inputs=[max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_txt, log_state], type="messages" ) - # Toggle 日志面板可见性 - dbg_toggle.change(lambda v: gr.update(visible=v), dbg_toggle, dbg_box) - -# 启动后台日志线程 -threading.Thread(target=log_worker, args=(dbg_box,), daemon=True).start() - demo.launch(server_name="0.0.0.0", server_port=30001)