#!/usr/bin/env python3
import os, json, uuid, time, argparse, random
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import requests

# ========== Config ==========
API_URL = "http://183.36.35.42:30000/v1/chat/completions"
API_KEY = "token-abc123"

HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}

# Temperature/sampling schedule (extend as needed)
TEMPS = [0.2, 0.5, 0.8, 1.0]

# ========== Functions ==========
def ask_teacher(prompt: str, temperature: float = 0.7, top_p: float = 0.95, n: int = 1, logprobs: int = 5):
    """Query the teacher model API and return the parsed JSON response."""
    payload = {
        "model": "Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
        "max_tokens": 1024,
        "logprobs": logprobs,
        "stream": False,
    }
    try:
        r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
        r.raise_for_status()
        return r.json()
    except Exception as e:
        # Report network/HTTP/parse errors as a dict so callers can detect and skip them
        return {"error": str(e)}

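# Illustrative usage (not executed): a single teacher call with the default sampling
# settings. On success the answer sits at choices[0].message.content in the
# OpenAI-compatible response; failures come back as {"error": "..."} from the
# except branch above.
#
#   resp = ask_teacher("Explain gradient checkpointing in one paragraph.")
#   if "choices" in resp:
#       print(resp["choices"][0]["message"]["content"])
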
def process_one(prompt: str):
    """Query the teacher at every temperature in TEMPS and return the collected records."""
    records = []
    for T in TEMPS:
        resp = ask_teacher(prompt, temperature=T)
        if "choices" not in resp:  # request failed, skip this temperature
            continue
        for ch in resp["choices"]:
            rec = {
                "id": str(uuid.uuid4()),
                "prompt": prompt,
                "teacher_answer": ch["message"]["content"],
                "logprobs_topk": ch.get("logprobs", {}),
                "decode_params": {"temperature": T, "top_p": 0.95, "n": 1},
                "ts": int(time.time())
            }
            records.append(rec)
    return records

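# Shape of one record emitted by process_one (values are placeholders, not real output):
#
#   {"id": "<uuid4>", "prompt": "<seed prompt>", "teacher_answer": "<model output>",
#    "logprobs_topk": {...}, "decode_params": {"temperature": 0.8, "top_p": 0.95, "n": 1},
#    "ts": <unix seconds>}
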
def worker(job):
    """Runs in a worker process: distill one prompt and append the results to the output file."""
    prompt, out_path = job
    records = process_one(prompt)
    if not records:
        return 0
    written = 0
    # All workers append to the same JSONL file; each record goes out as one short
    # line, which keeps interleaving unlikely in practice, though not guaranteed.
    with open(out_path, "a", encoding="utf-8") as f:
        for r in records:
            # Basic length filter on the teacher answer
            if 32 < len(r["teacher_answer"]) < 4000:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")
                written += 1
    # Return the number of records actually written so the progress counter
    # matches the line count of the output file.
    return written

# ========== Main entry point ==========
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed_file", type=str, default="seeds.txt", help="seed prompt file")
    parser.add_argument("--output", type=str, default="distilled.jsonl", help="output JSONL file")
    parser.add_argument("--target", type=int, default=1000000, help="target number of samples")
    parser.add_argument("--workers", type=int, default=cpu_count(), help="number of parallel processes")
    args = parser.parse_args()

    # Load seed prompts
    with open(args.seed_file, encoding="utf-8") as f:
        seeds = [l.strip() for l in f if l.strip()]

    # Resume support: count samples already written
    done = 0
    if os.path.exists(args.output):
        with open(args.output, encoding="utf-8") as f:
            done = sum(1 for _ in f)
        print(f"[INFO] {done} samples already on disk, target is {args.target}")

    # Main loop
    pool = Pool(processes=args.workers)

    try:
        with tqdm(total=args.target, initial=done) as pbar:
            while done < args.target:
                batch_prompts = random.choices(seeds, k=args.workers)
                jobs = [(p, args.output) for p in batch_prompts]
                for added in pool.map(worker, jobs):
                    done += added
                    pbar.update(added)
    except KeyboardInterrupt:
        print("[WARN] Interrupted by user, saving progress…")
        pool.terminate()
    finally:
        pool.close()
        pool.join()
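
# Example invocation (illustrative; adjust the script name and arguments to your setup):
#   python distill_teacher.py --seed_file seeds.txt --output distilled.jsonl \
#       --target 1000000 --workers 16
# The run is resumable: existing lines in --output count toward --target.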