From d18985e8a36d7efa0d5c7757b68d0e46ad968df7 Mon Sep 17 00:00:00 2001
From: hailin
Date: Sun, 27 Jul 2025 12:35:49 +0800
Subject: [PATCH] .

---
 meta_ui.py | 122 ++++++++++++++++++++++++++++++-----------------------
 1 file changed, 70 insertions(+), 52 deletions(-)

diff --git a/meta_ui.py b/meta_ui.py
index ca71760..3f887f7 100644
--- a/meta_ui.py
+++ b/meta_ui.py
@@ -1,61 +1,83 @@
 import os, json, datetime, textwrap, requests, gradio as gr
 from pathlib import Path
 from collections import deque
-import threading, time, queue, sys
+import threading, time, queue
 
-API_URL = "http://localhost:30000/generate"
-API_KEY = "token-abc123"
-MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")
+# ────────────────────────────────────
+# 1. Server & weight paths
+# ────────────────────────────────────
+API_URL = "http://localhost:30000/generate"
+API_KEY = "token-abc123"
+MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")  # ← change to your weights directory
 
-# ── Detect the model name from the weights ─────────────────────────────────────
-def detect_model_name(model_path: Path) -> str:
-    cfg = model_path / "config.json"
+def detect_model_name(path: Path) -> str:
+    cfg = path / "config.json"
     if cfg.exists():
-        with open(cfg, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        return data.get("architectures", [None])[0] or data.get("model_type") or model_path.name
-    return model_path.name
+        data = json.load(cfg.open())
+        return data.get("architectures", [None])[0] or data.get("model_type") or path.name
+    return path.name
+
 MODEL_NAME = detect_model_name(MODEL_PATH)
 
+now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
 
-def now():
-    return datetime.datetime.now().strftime("%H:%M:%S")
-
-# ── Global log queue for live viewing in the UI ────────────────────────────────
+# ────────────────────────────────────
+# 2. Log queue (live view in the UI)
+# ────────────────────────────────────
 LOG_Q: "queue.Queue[str]" = queue.Queue()
-def log(msg: str):
-    print(msg)            # keep stdout output
-    LOG_Q.put(msg)        # send to the UI
-def log_worker(log_box):  # background thread: flush queued logs to the gr.Textbox
-    buffer = deque(maxlen=300)
+
+def log(msg: str):
+    print(msg, flush=True)     # write to stdout
+    LOG_Q.put(msg)             # send to the UI
+
+def log_worker(log_box: gr.Textbox):
+    buf = deque(maxlen=400)    # keep the last 400 lines
     while True:
         line = LOG_Q.get()
-        buffer.append(line)
-        log_box.value = "\n".join(buffer)
-        time.sleep(0.01)
+        buf.append(line)
+        log_box.value = "\n".join(buf)
 
-# ── Call the backend ───────────────────────────────────────────────────────────
-def call_backend(text, sampling):
-    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
-    log(f"\n🟡 [{now()} payload]\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
+# ────────────────────────────────────
+# 3. Call /generate
+# ────────────────────────────────────
+def call_backend(text: str, sampling: dict):
+    payload = {
+        "model": MODEL_NAME,
+        "text": text,
+        "sampling_params": sampling
+    }
+    log(f"\n🟡 [{now()}] payload ↓\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
     try:
-        resp = requests.post(API_URL,
-                             headers={"Authorization": f"Bearer {API_KEY}",
-                                      "Content-Type": "application/json"},
-                             json=payload, timeout=180)
-        log(f"🟢 [{now()}] HTTP {resp.status_code}, first 400 chars of body={resp.text[:400]!r}")
-        if resp.status_code != 200:
-            return f"[HTTP {resp.status_code}] {resp.text[:300]}"
-        data = json.loads(resp.text)
-        return data.get("text", "").strip() or "[⚠ backend returned no text]"
+        resp = requests.post(
+            API_URL,
+            headers={"Authorization": f"Bearer {API_KEY}",
+                     "Content-Type": "application/json"},
+            json=payload, timeout=180
+        )
+        status = resp.status_code
+        body = resp.text
+        try:
+            data = resp.json()
+        except Exception:
+            data = {}
+        finish = data.get("meta_info", {}).get("finish_reason")
+        c_tok = data.get("meta_info", {}).get("completion_tokens")
+        log(f"🟢 [{now()}] HTTP {status} "
+            f"completion_tokens={c_tok} finish_reason={finish}\n"
+            f"🟢 first 800 chars of resp ↓\n{body[:800]!r}")
+        if status != 200:
+            return f"[HTTP {status}] {body[:300]}"
+        return data.get("text", "").strip() or "[⚠ backend returned empty text]"
     except Exception as e:
         log(f"[❌ request error] {e}")
         return f"[❌ request error] {e}"
 
-# ── Main Gradio chat function ──────────────────────────────────────────────────
+# ────────────────────────────────────
+# 4. ChatInterface callback
+# ────────────────────────────────────
 def chat(
-    user_msg, history,  # fixed ChatInterface args
-    max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_raw
+    user_msg, history,
+    max_new, temp, top_p, top_k,
+    rep_pen, pres_pen, stop_raw
 ):
     stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
     sampling = {
@@ -69,13 +91,15 @@ def chat(
     }
     return call_backend(user_msg, sampling)
 
-# ── Build the UI ───────────────────────────────────────────────────────────────
+# ────────────────────────────────────
+# 5. Gradio UI
+# ────────────────────────────────────
 with gr.Blocks(title="Base model debug UI") as demo:
     gr.Markdown(f"## 💬 Base model debug UI  \nWeights: **{MODEL_PATH.name}**")
 
-    # Sampling parameter controls
+    # Sampling controls
     with gr.Row():
-        max_new = gr.Slider(32, 32768, 1024, label="max_new_tokens")
+        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
         temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
     with gr.Row():
         top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
@@ -83,14 +107,10 @@ with gr.Blocks(title="Base model debug UI") as demo:
     with gr.Row():
         rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty")
         pres_pen= gr.Slider(0, 2.0, 0.0, step=0.05, label="presence_penalty")
-
     stop_text = gr.Textbox("", label="stop sequences (comma-separated)")
 
-    # Debug panel
-    with gr.Row():
-        toggle_dbg = gr.Checkbox(label="📜 Open / close debug panel", value=False)
-
-    dbg_box = gr.Textbox(label="Live log", lines=20, visible=False)
+    dbg_toggle = gr.Checkbox(label="📜 Show / hide debug panel", value=False)
+    dbg_box = gr.Textbox(label="Live log", lines=20, visible=False)
 
     # Chatbot
     chatbot = gr.ChatInterface(
@@ -99,10 +119,8 @@ with gr.Blocks(title="Base model debug UI") as demo:
         type="messages"
    )
 
-    # Toggle debug panel visibility
-    def show_dbg(flag: bool):
-        return gr.update(visible=flag)
-    toggle_dbg.change(show_dbg, inputs=toggle_dbg, outputs=dbg_box)
+    # Toggle log panel visibility
+    dbg_toggle.change(lambda v: gr.update(visible=v), dbg_toggle, dbg_box)
 
     # Start the background log thread
     threading.Thread(target=log_worker, args=(dbg_box,), daemon=True).start()
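
A quick way to sanity-check the /generate endpoint that call_backend() posts to, as a minimal sketch: it assumes the backend at localhost:30000 accepts the Bearer token above and the same {"text", "sampling_params"} payload the UI sends, and that it returns "text" plus "meta_info" as the patch expects; the prompt and sampling values below are placeholders.

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        headers={"Authorization": "Bearer token-abc123"},
        json={
            "text": "Hello",                           # placeholder prompt
            "sampling_params": {"max_new_tokens": 32,  # placeholder sampling values
                                "temperature": 0.8},
        },
        timeout=30,
    )
    resp.raise_for_status()
    data = resp.json()
    print(data.get("text", ""))
    print(data.get("meta_info", {}).get("finish_reason"),
          data.get("meta_info", {}).get("completion_tokens"))

If the two meta_info fields print as None, the backend simply does not return them; the extra logging added to call_backend() will then show None rather than fail.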