This commit is contained in:
hailin 2025-07-27 12:29:24 +08:00
parent 68a12b4b4a
commit 818a722192
1 changed file with 62 additions and 67 deletions

View File

@ -1,115 +1,110 @@
import os, json, datetime, textwrap, requests, gradio as gr import os, json, datetime, textwrap, requests, gradio as gr
from pathlib import Path from pathlib import Path
from collections import deque
import threading, time, queue, sys
# ── Backend endpoint & weight location ────────────────────────────────────────
API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base")
# ── Detect the model name from the weights directory ──────────────────────────
def detect_model_name(model_path: Path) -> str:
    """Best-effort model name for the weights under *model_path*.

    Prefers the first entry of ``architectures`` in ``config.json``, then
    ``model_type``; falls back to the directory name when the config is
    absent, unreadable, or carries neither key.
    """
    cfg = model_path / "config.json"
    if cfg.exists():
        try:
            data = json.loads(cfg.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            # A malformed/unreadable config must not abort startup;
            # degrade to the directory name instead.
            return model_path.name
        # Guard against "architectures": [] — the previous form
        # data.get("architectures", [None])[0] raised IndexError there.
        archs = data.get("architectures") or [None]
        return archs[0] or data.get("model_type") or model_path.name
    return model_path.name
# Resolved once at import time; sent as the "model" field of every request.
MODEL_NAME = detect_model_name(MODEL_PATH)
def now():
    """Current wall-clock time as an HH:MM:SS string (for log stamps)."""
    return f"{datetime.datetime.now():%H:%M:%S}"
# ── Global log queue so the UI can tail whatever the app prints ───────────────
LOG_Q: "queue.Queue[str]" = queue.Queue()


def log(msg: str):
    """Echo *msg* to stdout and push it onto LOG_Q for the live debug panel."""
    print(msg)      # keep the console copy
    LOG_Q.put(msg)  # hand the same line to the UI tailer thread
def log_worker(log_box):
    """Background-thread body: drain LOG_Q into *log_box*, keeping the
    newest 300 lines.

    NOTE(review): assigning to ``log_box.value`` from a worker thread may
    not refresh the rendered Gradio component — confirm against the Gradio
    version in use.
    """
    tail = deque(maxlen=300)
    while True:
        tail.append(LOG_Q.get())        # blocks until a line arrives
        log_box.value = "\n".join(tail)
        time.sleep(0.01)                # yield briefly between updates
# ── Backend call ──────────────────────────────────────────────────────────────
def call_backend(text, sampling):
    """POST *text* with *sampling* params to the SGLang /generate endpoint.

    Always returns a string: the generated text on success, otherwise a
    bracketed diagnostic (HTTP error snippet or exception message) so the
    chat UI can display the failure inline.
    """
    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
    log(f"\n🟡 [{now()} payload]\n{json.dumps(payload, ensure_ascii=False, indent=2)}")
    try:
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }
        resp = requests.post(API_URL, headers=headers, json=payload, timeout=180)
        log(f"🟢 [{now()}] HTTP {resp.status_code}, body前400={resp.text[:400]!r}")
        if resp.status_code != 200:
            return f"[HTTP {resp.status_code}] {resp.text[:300]}"
        data = json.loads(resp.text)
        return data.get("text", "").strip() or "[⚠ 后端无 text]"
    except Exception as e:  # deliberately broad: any failure surfaces in the chat box
        log(f"[❌ 请求异常] {e}")
        return f"[❌ 请求异常] {e}"
# ── Main Gradio chat handler ──────────────────────────────────────────────────
def chat(
    user_msg, history,                       # fixed ChatInterface signature
    max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_raw
):
    """Assemble sampling params from the UI controls and query the backend."""
    # Split the comma-separated stop string; blanks are discarded.
    stop_seqs = [piece.strip() for piece in stop_raw.split(",") if piece.strip()]
    sampling = {
        "max_new_tokens": int(max_new),
        "temperature": temp,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
    }
    if stop_seqs:
        sampling["stop"] = stop_seqs
    return call_backend(user_msg, sampling)
# ── UI construction ───────────────────────────────────────────────────────────
with gr.Blocks(title="Base 模型调试界面") as demo:
    gr.Markdown(f"## 💬 Base 模型调试界面  \n权重 **{MODEL_PATH.name}**")

    # Sampling-parameter controls.
    with gr.Row():
        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
        temp = gr.Slider(0, 1.5, 0.8, step=0.05, label="temperature")
    with gr.Row():
        top_p = gr.Slider(0, 1, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0, 2.0, 0.0, step=0.05, label="presence_penalty")
    stop_text = gr.Textbox("", label="stop 序列(逗号分隔)")

    # Debug panel (hidden until the checkbox is ticked).
    with gr.Row():
        toggle_dbg = gr.Checkbox(label="📜 打开 / 关闭 Debug 面板", value=False)
    # NOTE(review): the diff rendering lost indentation — dbg_box may belong
    # inside the Row above; confirm against the original layout.
    dbg_box = gr.Textbox(label="实时日志", lines=20, visible=False)

    # Chat surface wired to the sampling controls above.
    chatbot = gr.ChatInterface(
        fn=chat,
        additional_inputs=[max_new, temp, top_p, top_k, rep_pen, pres_pen, stop_text],
        type="messages",
    )

    # Show/hide the debug textbox when the checkbox flips.
    def show_dbg(flag: bool):
        return gr.update(visible=flag)

    toggle_dbg.change(show_dbg, inputs=toggle_dbg, outputs=dbg_box)

    # Background thread that tails LOG_Q into the debug textbox.
    threading.Thread(target=log_worker, args=(dbg_box,), daemon=True).start()

demo.launch(server_name="0.0.0.0", server_port=30001)