hailin 2025-07-27 12:35:49 +08:00
parent 4071f51150
commit d18985e8a3
1 changed file with 70 additions and 52 deletions


@@ -1,61 +1,83 @@
 import os, json, datetime, textwrap, requests, gradio as gr
 from pathlib import Path
 from collections import deque
-import threading, time, queue, sys
+import threading, time, queue
+
+# ────────────────────────────────────
+# 1. Server & weight paths
+# ────────────────────────────────────
 API_URL = "http://localhost:30000/generate"
 API_KEY = "token-abc123"
-MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")
+MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")  # ← change to your weight directory
 
-# ── Detect the model name from the weights ──────────────────────────────
-def detect_model_name(model_path: Path) -> str:
-    cfg = model_path / "config.json"
+def detect_model_name(path: Path) -> str:
+    cfg = path / "config.json"
     if cfg.exists():
-        with open(cfg, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        return data.get("architectures", [None])[0] or data.get("model_type") or model_path.name
-    return model_path.name
+        data = json.load(cfg.open())
+        return data.get("architectures", [None])[0] or data.get("model_type") or path.name
+    return path.name
 
 MODEL_NAME = detect_model_name(MODEL_PATH)
+now = lambda: datetime.datetime.now().strftime("%H:%M:%S")
 
-def now():
-    return datetime.datetime.now().strftime("%H:%M:%S")
-
-# ── Global log queue, for live viewing in the UI ────────────────────────
+# ────────────────────────────────────
+# 2. Log queue (live view in the UI)
+# ────────────────────────────────────
 LOG_Q: "queue.Queue[str]" = queue.Queue()
+
 def log(msg: str):
-    print(msg)        # keep stdout output
-    LOG_Q.put(msg)    # send to the UI
+    print(msg, flush=True)   # write to stdout
+    LOG_Q.put(msg)           # send to the UI
 
-def log_worker(log_box):   # background thread: flush queued log lines into gr.Textbox
-    buffer = deque(maxlen=300)
+def log_worker(log_box: gr.Textbox):
+    buf = deque(maxlen=400)  # keep the last 400 lines
     while True:
         line = LOG_Q.get()
-        buffer.append(line)
-        log_box.value = "\n".join(buffer)
+        buf.append(line)
+        log_box.value = "\n".join(buf)
+        time.sleep(0.01)
 
-# ── Call the backend ─────────────────────────────────────────────────────
-def call_backend(text, sampling):
-    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
-    log(f"\n🟡 [{now()} payload]\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
+# ────────────────────────────────────
+# 3. Call /generate
+# ────────────────────────────────────
+def call_backend(text: str, sampling: dict):
+    payload = {
+        "model": MODEL_NAME,
+        "text": text,
+        "sampling_params": sampling
+    }
+    log(f"\n🟡 [{now()}] payload ↓\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
     try:
-        resp = requests.post(API_URL,
-                             headers={"Authorization": f"Bearer {API_KEY}",
-                                      "Content-Type": "application/json"},
-                             json=payload, timeout=180)
-        log(f"🟢 [{now()}] HTTP {resp.status_code}, first 400 chars of body={resp.text[:400]!r}")
-        if resp.status_code != 200:
-            return f"[HTTP {resp.status_code}] {resp.text[:300]}"
-        data = json.loads(resp.text)
-        return data.get("text", "").strip() or "[⚠ no text from backend]"
+        resp = requests.post(
+            API_URL,
+            headers={"Authorization": f"Bearer {API_KEY}",
+                     "Content-Type": "application/json"},
+            json=payload, timeout=180
+        )
+        status = resp.status_code
+        body = resp.text
+        try:
+            data = resp.json()
+        except Exception:
+            data = {}
+        finish = data.get("meta_info", {}).get("finish_reason")
+        c_tok = data.get("meta_info", {}).get("completion_tokens")
+        log(f"🟢 [{now()}] HTTP {status} "
+            f"completion_tokens={c_tok} finish_reason={finish}\n"
+            f"🟢 first 800 chars of resp ↓\n{body[:800]!r}")
+        if status != 200:
+            return f"[HTTP {status}] {body[:300]}"
+        return data.get("text", "").strip() or "[⚠ backend returned empty text]"
     except Exception as e:
         log(f"[❌ request error] {e}")
         return f"[❌ request error] {e}"
 
-# ── Gradio main chat function ────────────────────────────────────────────
+# ────────────────────────────────────
+# 4. ChatInterface callback
+# ────────────────────────────────────
 def chat(
-    user_msg, history,            # fixed arguments from ChatInterface
-    max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_raw
+    user_msg, history,
+    max_new, temp, top_p, top_k,
+    rep_pen, pres_pen, stop_raw
 ):
     stop = [s.strip() for s in stop_raw.split(",") if s.strip()] or None
     sampling = {
@@ -69,13 +91,15 @@ def chat(
     }
     return call_backend(user_msg, sampling)
 
-# ── UI construction ──────────────────────────────────────────────────────
+# ────────────────────────────────────
+# 5. Gradio UI
+# ────────────────────────────────────
 with gr.Blocks(title="Base model debug UI") as demo:
     gr.Markdown(f"## 💬 Base model debug UI  \nWeights: **{MODEL_PATH.name}**")
 
-    # Sampling parameter controls
+    # Sampling controls
     with gr.Row():
-        max_new = gr.Slider(32, 32768, 1024, label="max_new_tokens")
+        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
         temp    = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
     with gr.Row():
         top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
@@ -83,13 +107,9 @@ with gr.Blocks(title="Base model debug UI") as demo:
     with gr.Row():
         rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty")
         pres_pen= gr.Slider(0, 2.0, 0.0, step=0.05, label="presence_penalty")
     stop_text = gr.Textbox("", label="stop sequences (comma-separated)")
 
-    # Debug panel
-    with gr.Row():
-        toggle_dbg = gr.Checkbox(label="📜 Open / close debug panel", value=False)
+    dbg_toggle = gr.Checkbox(label="📜 Show / hide debug panel", value=False)
     dbg_box = gr.Textbox(label="Live log", lines=20, visible=False)
 
     # Chatbot
@@ -99,10 +119,8 @@ with gr.Blocks(title="Base model debug UI") as demo:
         type="messages"
     )
 
-    # Toggle debug panel display
-    def show_dbg(flag: bool):
-        return gr.update(visible=flag)
-    toggle_dbg.change(show_dbg, inputs=toggle_dbg, outputs=dbg_box)
+    # Toggle log panel visibility
+    dbg_toggle.change(lambda v: gr.update(visible=v), dbg_toggle, dbg_box)
 
     # Start the background log thread
     threading.Thread(target=log_worker, args=(dbg_box,), daemon=True).start()
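
The commit does not name the serving backend behind port 30000; the /generate route, the text + sampling_params payload, and the meta_info.completion_tokens / finish_reason fields that call_backend reads are consistent with an SGLang-style generation server. A minimal standalone request in the same shape can be used to sanity-check that backend before opening the Gradio UI; it is a sketch under that assumption, and the prompt, sampling values, and model string below are placeholders rather than values from this commit.

import json, requests

API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"                     # same placeholder token as the script above

payload = {
    "model": "Qwen3-30B-A3B-Base",           # the UI fills this via detect_model_name(); placeholder here
    "text": "The capital of France is",      # placeholder prompt
    "sampling_params": {"max_new_tokens": 32, "temperature": 0.8, "top_p": 0.95},
}
resp = requests.post(
    API_URL,
    headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
    json=payload,
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data.get("text", ""))                             # same field the UI shows in the chat window
print(json.dumps(data.get("meta_info", {}), indent=2))  # completion_tokens / finish_reason, as logged by call_backend

If this request succeeds but the UI still shows an empty reply, the debug panel's payload and response dumps should point to whichever side (payload shape or response parsing) is off.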