# Source: sglang 0.4.8.post1 — meta_ui.py (Gradio front-end for an SGLang /generate server)

import os, json, datetime, textwrap, requests, gradio as gr
from pathlib import Path
#───────────────────────────────────────────────────────────────────────────────
# 1. Server endpoint & model-weight path
#───────────────────────────────────────────────────────────────────────────────
API_URL = "http://localhost:30000/generate"
API_KEY = "token-abc123"
MODEL_PATH = Path("/root/.cradle/Alibaba/Qwen3-30B-A3B-Base") # ← change to the path passed in by the supervisor
# Auto-detect the model name from the weights; fall back to the directory name.
def detect_model_name(model_path: Path) -> str:
    """Return a display name for the model stored at *model_path*.

    Preference order, read from ``config.json`` in the weight directory
    (Qwen / LLaMA / GPTNeoX all ship one): first entry of ``architectures``,
    then ``model_type``, then the directory name. Any missing, unreadable,
    or malformed config.json also falls back to the directory name instead
    of raising, so the UI can still start.
    """
    cfg = model_path / "config.json"
    try:
        with open(cfg, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        # No config.json, unreadable file, or corrupt JSON — use the dir name.
        return model_path.name
    # `or [None]` guards against an explicit empty "architectures" list.
    return (
        (data.get("architectures") or [None])[0]
        or data.get("model_type")
        or model_path.name
    )
MODEL_NAME = detect_model_name(MODEL_PATH)
def now() -> str:
    """Current wall-clock time formatted ``HH:MM:SS`` (used as a log prefix)."""
    time_format = "%H:%M:%S"
    current = datetime.datetime.now()
    return current.strftime(time_format)
#───────────────────────────────────────────────────────────────────────────────
# 2. Call the SGLang /generate endpoint
#───────────────────────────────────────────────────────────────────────────────
def call_backend(text, sampling):
    """POST *text* with *sampling* params to the SGLang /generate endpoint.

    Returns the generated text on success, or a human-readable error string on
    any failure (HTTP error, transport error, non-JSON body) — never raises, so
    the Gradio handlers can display the result directly.
    """
    payload = {"model": MODEL_NAME, "text": text, "sampling_params": sampling}
    # Log a truncated payload so huge prompts don't flood the console.
    print(f"\n🟡 [{now()} payload] {json.dumps(payload, ensure_ascii=False)[:400]}")
    try:
        resp = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
            json=payload, timeout=180
        )
    except requests.RequestException as exc:
        # Timeout / connection refused / DNS failure — surface it in the UI
        # instead of crashing the Gradio callback.
        return f"[⚠ 请求失败] {exc}"
    if resp.status_code != 200:
        return f"[HTTP {resp.status_code}] {resp.text[:300]}"
    try:
        return json.loads(resp.text).get("text", "").strip() or "[⚠ 后端无 text]"
    except json.JSONDecodeError:
        snippet = textwrap.shorten(resp.text, 300, placeholder="")
        return f"[⚠ JSON 解析失败] {snippet}"
#───────────────────────────────────────────────────────────────────────────────
# 3. Gradio chat handler
#───────────────────────────────────────────────────────────────────────────────
def chat(
    user_msg, history,
    max_new, temperature, top_p, top_k,
    rep_pen, pres_pen, stop_raw
):
    """Gradio ChatInterface callback: assemble sampling params, query backend.

    *history* is accepted (ChatInterface passes it) but unused — the base
    model is queried single-turn with the raw user message.
    """
    sampling = {
        "max_new_tokens": int(max_new),
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": rep_pen,
        "presence_penalty": pres_pen,
    }
    # Stop sequences arrive as a comma-separated string; drop blanks.
    stop_sequences = [part.strip() for part in stop_raw.split(",") if part.strip()]
    if stop_sequences:
        sampling["stop"] = stop_sequences
    return call_backend(user_msg, sampling)
#───────────────────────────────────────────────────────────────────────────────
# 4. Gradio UI
#───────────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Base 模型对话界面") as demo:
    gr.Markdown(f"## 💬 Base 模型对话界面 \n*正在使用权重* **{MODEL_PATH.name}**")
    # ── Sampling-parameter controls ──────────────────────────────────────────
    with gr.Row():
        max_new = gr.Slider(32, 32768, 2048, label="max_new_tokens")
        temperature = gr.Slider(0.0, 1.5, 0.8, step=0.05, label="temperature")
    with gr.Row():
        top_p = gr.Slider(0.0, 1.0, 0.95, step=0.01, label="top_p")
        top_k = gr.Slider(0, 200, 50, step=1, label="top_k")
    with gr.Row():
        rep_pen = gr.Slider(0.8, 2.0, 1.05, step=0.01, label="repetition_penalty")
        pres_pen = gr.Slider(0.0, 2.0, 0.0, step=0.05, label="presence_penalty")
    stop_text = gr.Textbox("", label="stop 序列(逗号分隔)", placeholder="如: ###,END")
    # ── Chatbot & buttons ────────────────────────────────────────────────────
    ping_btn = gr.Button("🔁 测试 API")
    ping_out = gr.Textbox(label="API 测试结果", interactive=False)
    chat_ui = gr.ChatInterface(
        fn=chat,
        additional_inputs=[max_new, temperature, top_p, top_k, rep_pen, pres_pen, stop_text],
        type="messages"
    )

    def ping_api(max_new, temperature, top_p, top_k, rep_pen, pres_pen, stop_raw):
        """Health-check: send a fixed "Ping?" prompt with the current slider
        values. Delegates to chat() so the sampling-dict construction is not
        duplicated here; output truncated to 200 chars for the result box."""
        return chat("Ping?", [], max_new, temperature, top_p, top_k,
                    rep_pen, pres_pen, stop_raw)[:200]

    ping_btn.click(
        fn=ping_api,
        inputs=[max_new, temperature, top_p, top_k, rep_pen, pres_pen, stop_text],
        outputs=ping_out
    )

# Bind on all interfaces so the UI is reachable from outside the container.
demo.launch(server_name="0.0.0.0", server_port=30001)