jd_train/distill_hf.py

108 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import os, json, uuid, time, argparse, random, signal
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import requests
# ========== Configuration ==========
# Teacher endpoint and credentials. Both can be overridden via the environment
# (TEACHER_API_URL / TEACHER_API_KEY) so the secret need not live in source
# control; the fallbacks are the previous hard-coded values, so existing runs
# behave exactly as before. An empty environment value also falls back.
# NOTE(review): the default key below is committed to the repository -
# consider rotating it and relying on TEACHER_API_KEY instead.
API_URL = (
    os.environ.get("TEACHER_API_URL")
    or "http://183.36.35.42:30000/v1/chat/completions"
)
API_KEY = os.environ.get("TEACHER_API_KEY") or "token-abc123"
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}
# Sampling temperatures: process_one sends each prompt to the teacher once per
# value listed here (extend as needed).
TEMPS = [0.2, 0.5, 0.8, 1.0]
# ========== 函数 ==========
def ask_teacher(
    prompt: str,
    temperature: float = 0.7,
    top_p: float = 0.95,
    n: int = 1,
    logprobs: int = 5,
    *,
    max_tokens: int = 1024,
    model: str = "Qwen3-32B",
    timeout: float = 60,
):
    """Call the teacher model's chat-completions endpoint once.

    Args:
        prompt: Text sent as the single user message.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        n: Number of completions requested.
        logprobs: Number of top alternatives requested per generated token.
            ``0``, a negative value or ``None`` disables log-probabilities.
        max_tokens: Generation length cap (previously hard-coded to 1024).
        model: Model name sent in the request (previously hard-coded).
        timeout: HTTP timeout in seconds (previously hard-coded to 60).

    Returns:
        The decoded JSON response on success, otherwise ``{"error": <message>}``.
        Transport/HTTP/decoding failures are returned rather than raised so a
        single failed request does not abort the calling worker.
    """
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
        "max_tokens": max_tokens,
        "stream": False,
    }
    # The chat-completions API defines `logprobs` as a boolean switch and takes
    # the per-token alternative count in `top_logprobs`; servers that validate
    # against that schema may reject an integer in `logprobs`.
    if logprobs and logprobs > 0:
        payload["logprobs"] = True
        payload["top_logprobs"] = logprobs
    try:
        r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
        r.raise_for_status()
        return r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON response body on older `requests` versions
        # (newer ones raise a RequestException subclass).
        return {"error": str(e)}
def process_one(prompt: str):
    """Sample the teacher once per temperature in ``TEMPS`` and build records.

    Args:
        prompt: Seed prompt sent to the teacher.

    Returns:
        A list of record dicts with keys ``id``, ``prompt``, ``teacher_answer``,
        ``logprobs_topk``, ``decode_params`` and ``ts``. Failed requests (error
        dicts from ``ask_teacher``) and choices without string content are
        skipped, so the list may be empty.
    """
    # Sent to the teacher AND recorded in decode_params; one name keeps them in sync.
    top_p = 0.95
    records = []
    for temp in TEMPS:
        resp = ask_teacher(prompt, temperature=temp, top_p=top_p)
        choices = resp.get("choices") if isinstance(resp, dict) else None
        if not choices:
            # Request failed ({"error": ...}) or returned nothing; try next temperature.
            continue
        for choice in choices:
            content = (choice.get("message") or {}).get("content")
            # The API schema allows `content` to be null; such a choice has no text
            # to distil, and passing None on would break length filtering downstream.
            if not isinstance(content, str):
                continue
            records.append({
                "id": str(uuid.uuid4()),
                "prompt": prompt,
                "teacher_answer": content,
                # The key may be present with a null value; normalise that to {}.
                "logprobs_topk": choice.get("logprobs") or {},
                "decode_params": {"temperature": temp, "top_p": top_p, "n": 1},
                "ts": int(time.time()),
            })
    return records
def worker(job, min_len: int = 32, max_len: int = 4000):
    """Pool task: distil one prompt and append the surviving records as JSONL.

    Args:
        job: ``(prompt, out_path)`` tuple, as built by the main loop.
        min_len: Answers must be strictly longer than this many characters.
        max_len: Answers must be strictly shorter than this many characters.

    Returns:
        The number of records actually written (after filtering). This matches
        the per-line count the resume logic reads back from the output file,
        so the progress counter and the file stay in agreement.
    """
    prompt, out_path = job
    kept = [
        json.dumps(rec, ensure_ascii=False) + "\n"
        for rec in process_one(prompt)
        if min_len < len(rec["teacher_answer"]) < max_len  # basic length filter
    ]
    if not kept:
        return 0
    data = memoryview("".join(kept).encode("utf-8"))
    # Several pool processes append to the same file. Writing this job's lines
    # through an unbuffered append-mode handle hands them to the OS in a single
    # write call (barring a short write), instead of letting a text buffer split
    # a long line into pieces that could interleave with another worker's output.
    with open(out_path, "ab", buffering=0) as f:
        while data:
            written = f.write(data)
            data = data[written:]
    return len(kept)
# ========== 主入口 ==========
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed_file", type=str, default="seeds.txt", help="种子 prompt 文件")
    parser.add_argument("--output", type=str, default="distilled.jsonl", help="输出 JSONL 文件")
    parser.add_argument("--target", type=int, default=1000000, help="目标样本数")
    parser.add_argument("--workers", type=int, default=cpu_count(), help="并行进程数")
    args = parser.parse_args()
    if args.workers < 1:
        # Pool() rejects 0 processes, and a batch of 0 prompts would never progress.
        parser.error("--workers must be at least 1")

    # Read seed prompts: one per line, blank lines ignored. Explicit UTF-8 so
    # non-ASCII prompts do not depend on the platform's default locale encoding.
    with open(args.seed_file, encoding="utf-8") as f:
        seeds = [line.strip() for line in f if line.strip()]
    if not seeds:
        # random.choices() raises IndexError on an empty population.
        parser.error(f"no prompts found in {args.seed_file}")

    # Resume: every line already in the output file counts as one finished sample.
    # Counted in binary mode so an odd byte in the file cannot break decoding.
    done = 0
    if os.path.exists(args.output):
        with open(args.output, "rb") as f:
            done = sum(1 for _ in f)
    print(f"[INFO] 已有 {done} 条,目标 {args.target}")

    # Workers ignore SIGINT so Ctrl+C is handled only by this process, which
    # then terminates the pool. signal.signal is passed directly as the
    # initializer so it is picklable under any start method.
    pool = Pool(
        processes=args.workers,
        initializer=signal.signal,
        initargs=(signal.SIGINT, signal.SIG_IGN),
    )
    idle_sleep = 1.0  # backoff delay (seconds) after a batch that produced nothing
    try:
        with tqdm(total=args.target, initial=done) as pbar:
            while done < args.target:
                batch_prompts = random.choices(seeds, k=args.workers)
                jobs = [(p, args.output) for p in batch_prompts]
                batch_added = 0
                for added in pool.map(worker, jobs):
                    done += added
                    batch_added += added
                    pbar.update(added)
                if batch_added == 0:
                    # Nothing written (e.g. teacher unreachable): back off
                    # exponentially, capped at 60 s, rather than re-sending requests
                    # in a tight loop.
                    time.sleep(idle_sleep)
                    idle_sleep = min(idle_sleep * 2, 60.0)
                else:
                    idle_sleep = 1.0
    except KeyboardInterrupt:
        print("[WARN] 用户中断,保存进度中…")
        pool.terminate()
    finally:
        pool.close()
        pool.join()