This commit is contained in:
hailin 2025-07-27 12:42:37 +08:00
parent d18985e8a3
commit 82e5957f8e
1 changed file with 68 additions and 77 deletions

View File

@ -1,103 +1,86 @@
import datetime
import json
import queue
import textwrap
import threading
import time
from collections import deque
from pathlib import Path

import gradio as gr
import requests
# ───────────────────── Basic configuration ─────────────────────
API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")


def model_name(path: Path) -> str:
    """Resolve a display name for the checkpoint directory *path*.

    Prefers ``architectures[0]`` from ``config.json``, then ``model_type``,
    and finally falls back to the directory name when no config exists or
    neither field is usable.
    """
    cfg = path / "config.json"
    if cfg.exists():
        # Use a context manager so the file handle is closed deterministically
        # (the original ``json.load(cfg.open())`` leaked the handle).
        with cfg.open() as fh:
            data = json.load(fh)
        # ``or [None]`` guards an explicit-but-empty "architectures" list,
        # which would otherwise raise IndexError.
        return (data.get("architectures") or [None])[0] \
            or data.get("model_type") or path.name
    return path.name


MODEL_NAME = model_name(MODEL_PATH)

# Timestamp helper for log lines, e.g. "12:42:37".
now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
# ───────────────────── Log queue ─────────────────────
# Producer/consumer pair: ``log`` pushes lines from any thread; the UI
# timer drains them via ``consume_logs``.
LOG_Q: "queue.Queue[str]" = queue.Queue()


def log(msg):
    """Write *msg* to stdout and enqueue it for the UI log panel."""
    print(msg, flush=True)
    LOG_Q.put(msg)


def consume_logs(state_txt: str) -> str:
    """Drain LOG_Q and append the new lines to the accumulated log text.

    Called periodically by the UI timer.  Keeps only the most recent
    400 lines.  Returns the updated text (unchanged if the queue is empty).
    """
    buf = deque(state_txt.splitlines(), maxlen=400)
    while True:
        try:
            # EAFP: ``empty()``/``get()`` is racy with a producer thread;
            # ``get_nowait`` + Empty is the safe drain pattern.
            buf.append(LOG_Q.get_nowait())
        except queue.Empty:
            break
    return "\n".join(buf)
# ───────────────────── Backend call ─────────────────────
def backend(text, sampling):
    """POST a generation request to the sglang-style ``/generate`` endpoint.

    Logs the outgoing payload and the response summary via ``log``.
    Returns the generated text on success; on any failure returns a
    human-readable error string instead of raising, so the chat UI
    always receives something displayable.
    """
    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
    log(f"\n🟡 [{now()}] payload\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        r = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {API_KEY}",
                     "Content-Type": "application/json"},
            json=payload,
            timeout=180,
        )
        try:
            data = r.json()
        except Exception:
            data = {}  # non-JSON body (e.g. proxy/gateway error page)
        meta = data.get("meta_info", {})
        fr = meta.get("finish_reason")
        ctok = meta.get("completion_tokens")
        log(f"🟢 [{now()}] HTTP {r.status_code} tokens={ctok} finish={fr}\n"
            f"🟢 resp800={r.text[:800]!r}")
        if r.status_code != 200:
            return f"[HTTP {r.status_code}] {r.text[:300]}"
        return data.get("text", "").strip() or "[⚠ 空]"
    except Exception as e:
        # Network errors, timeouts, DNS failures, etc.
        log(f"[❌ 请求异常] {e}")
        return f"[❌ 请求异常] {e}"
# ───────────────────── Chat callback ─────────────────────
def chat(
    user, history,
    max_new, temp, top_p, top_k,
    rep_pen, pres_pen, stop_raw,
    log_state
):
    """ChatInterface callback: build sampling params and call the backend.

    ``history`` is accepted but unused — the base model is prompted with
    the raw user text only.  ``log_state`` is returned unchanged; the log
    panel refreshes on its own timer, not from this callback.
    """
    # Comma-separated stop sequences; empty entries dropped, None if none.
    stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
    samp = {
        "max_new_tokens": int(max_new),
        "temperature": temp,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
        **({"stop": stop} if stop else {}),
    }
    out = backend(user, samp)
    # Reply plus the pass-through log state.
    return out, log_state
# ───────────────────── Gradio UI ─────────────────────
with gr.Blocks(title="调试界面") as demo:
    gr.Markdown(f"## 💬 调试界面 \n权重 **{MODEL_PATH.name}**")

    # Sampling parameter controls
    with gr.Row():
        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
        temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
        top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0, 2, 0.0, step=0.05, label="presence_penalty")
    stop_txt = gr.Textbox("", label="stop 序列(逗号分隔)")

    dbg_chk = gr.Checkbox(label="📜 显示 Debug 面板", value=True)
    log_box = gr.Textbox(label="实时日志", lines=20, interactive=False, visible=True)
    log_state = gr.State("")  # accumulated log text (source of truth)

    # Poll the log queue once per second.
    # NOTE(review): ``gr.Interval``/``set_event_trigger`` from the draft is
    # not a public Gradio API; ``gr.Timer`` + ``.tick`` (Gradio ≥ 4.44) is
    # the supported way to drive periodic callbacks — confirm the installed
    # Gradio version supports it.
    ticker = gr.Timer(1.0)
    ticker.tick(fn=consume_logs, inputs=log_state, outputs=log_state)

    # Mirror the accumulated text into the visible textbox.
    log_state.change(lambda txt: gr.update(value=txt), log_state, log_box)

    # Toggle debug-panel visibility.
    dbg_chk.change(lambda v: gr.update(visible=v), dbg_chk, log_box)

    # Chatbot.  ``chat`` returns (reply, log_state), so the extra state
    # must be declared as an additional output as well as an input.
    chatbot = gr.ChatInterface(
        fn=chat,
        additional_inputs=[max_new, temp, top_p, top_k,
                           rep_pen, pres_pen, stop_txt, log_state],
        additional_outputs=[log_state],
        type="messages",
    )

demo.launch(server_name="0.0.0.0", server_port=30001)