#!/usr/bin/env python3
"""Distill answers from a teacher model into a JSONL dataset.

Seed prompts are drawn with replacement from ``--seed_file`` and sent to an
OpenAI-compatible chat-completions endpoint once per temperature in ``TEMPS``.
Every returned answer that passes the length filter is appended to
``--output`` as one JSON object per line. Re-running the script resumes from
the number of lines already present in the output file.
"""
import argparse
import json
import os
import random
import signal
import sys
import time
import uuid
from multiprocessing import Pool, cpu_count

import requests
from tqdm import tqdm

# ========== Configuration ==========
# Endpoint and token can be overridden from the environment so the credential
# does not have to be edited into (or committed with) the source file.
API_URL = os.environ.get(
    "TEACHER_API_URL", "http://183.36.35.42:30000/v1/chat/completions"
)
API_KEY = os.environ.get("TEACHER_API_KEY", "token-abc123")
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}

# Sampling temperatures: every seed prompt is queried once per value
# (extend as needed).
TEMPS = [0.2, 0.5, 0.8, 1.0]
# Decoding parameters used by process_one; they are also stored in each record.
TOP_P = 0.95
N_PER_REQUEST = 1
# An answer is kept only if MIN_ANSWER_CHARS < len(answer) < MAX_ANSWER_CHARS.
MIN_ANSWER_CHARS = 32
MAX_ANSWER_CHARS = 4000


# ========== Functions ==========
def ask_teacher(prompt: str, temperature: float = 0.7, top_p: float = 0.95,
                n: int = 1, logprobs: int = 5):
    """Send one chat-completions request to the teacher model.

    Args:
        prompt: Text sent as the single user message.
        temperature: Sampling temperature forwarded to the server.
        top_p: Nucleus-sampling threshold forwarded to the server.
        n: Number of completions requested.
        logprobs: Number of top alternatives to return per generated token;
            0 (or any falsy value) disables log-probabilities.

    Returns:
        The decoded JSON response on success, otherwise ``{"error": <msg>}``
        when the request fails, returns a non-2xx status, or the body is not
        valid JSON.
    """
    payload = {
        "model": "Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "n": n,
        "max_tokens": 1024,
        "stream": False,
    }
    if logprobs:
        # The chat-completions API takes a boolean ``logprobs`` flag plus an
        # integer ``top_logprobs``; an integer ``logprobs`` is rejected by
        # servers that validate the request schema.
        payload["logprobs"] = True
        payload["top_logprobs"] = logprobs
    try:
        r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
        r.raise_for_status()
        return r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON response body.
        return {"error": str(e)}


def process_one(prompt: str):
    """Query the teacher at every temperature in TEMPS and collect answers.

    Failed requests (responses without ``choices``) and choices whose message
    has no string ``content`` are skipped.

    Returns:
        A list of record dicts with keys ``id``, ``prompt``,
        ``teacher_answer``, ``logprobs_topk``, ``decode_params`` and ``ts``.
    """
    records = []
    for temp in TEMPS:
        resp = ask_teacher(prompt, temperature=temp, top_p=TOP_P,
                           n=N_PER_REQUEST)
        choices = resp.get("choices")
        if not choices:  # request failed or the server returned nothing
            continue
        for ch in choices:
            content = (ch.get("message") or {}).get("content")
            if not isinstance(content, str):
                # e.g. a null content field: nothing usable to distill.
                continue
            records.append({
                "id": str(uuid.uuid4()),
                "prompt": prompt,
                "teacher_answer": content,
                # Servers may send "logprobs": null; normalise to {} so the
                # field has the same type in every record.
                "logprobs_topk": ch.get("logprobs") or {},
                "decode_params": {
                    "temperature": temp,
                    "top_p": TOP_P,
                    "n": N_PER_REQUEST,
                },
                "ts": int(time.time()),
            })
    return records


def worker(job):
    """Pool task: distill one prompt and append the kept records to a file.

    Args:
        job: A ``(prompt, out_path)`` tuple.

    Returns:
        The number of JSONL lines actually written (after the length filter),
        so the caller's progress count matches the output file.
    """
    prompt, out_path = job
    kept = [
        r for r in process_one(prompt)
        if MIN_ANSWER_CHARS < len(r["teacher_answer"]) < MAX_ANSWER_CHARS
    ]
    if not kept:
        return 0
    # Several worker processes append to the same file: serialise everything
    # first and issue one binary append, rather than many small buffered text
    # writes whose flushes could interleave with another process's lines.
    data = "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in kept)
    with open(out_path, "ab") as f:
        f.write(data.encode("utf-8"))
    return len(kept)


def _ignore_sigint():
    """Pool initializer: only the parent process reacts to Ctrl+C."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def _count_lines(path):
    """Return the number of lines in *path*, or 0 if it does not exist."""
    if not os.path.exists(path):
        return 0
    # Binary mode: counting newlines needs no decoding (and no locale).
    with open(path, "rb") as f:
        return sum(1 for _ in f)


# ========== Entry point ==========
def main():
    """Parse CLI arguments and generate samples until --target is reached."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed_file", type=str, default="seeds.txt", help="种子 prompt 文件")
    parser.add_argument("--output", type=str, default="distilled.jsonl", help="输出 JSONL 文件")
    parser.add_argument("--target", type=int, default=1000000, help="目标样本数")
    parser.add_argument("--workers", type=int, default=cpu_count(), help="并行进程数")
    parser.add_argument("--max_empty_rounds", type=int, default=20,
                        help="连续多少轮无新增样本后退出(0 表示不限制)")
    args = parser.parse_args()

    # Read seed prompts, skipping blank lines.
    with open(args.seed_file, encoding="utf-8") as f:
        seeds = [line.strip() for line in f if line.strip()]
    if not seeds:
        sys.exit(f"[ERROR] 种子文件中没有可用的 prompt: {args.seed_file}")

    # Resume support: every existing output line is one finished sample.
    done = _count_lines(args.output)
    print(f"[INFO] 已有 {done} 条,目标 {args.target}")

    pool = Pool(processes=args.workers, initializer=_ignore_sigint)
    empty_rounds = 0
    aborted = False
    try:
        with tqdm(total=args.target, initial=done) as pbar:
            while done < args.target:
                batch_prompts = random.choices(seeds, k=args.workers)
                jobs = [(p, args.output) for p in batch_prompts]
                added = sum(pool.map(worker, jobs))
                done += added
                pbar.update(added)
                if added:
                    empty_rounds = 0
                    continue
                # Every request in the round failed or was filtered out; stop
                # instead of spinning forever against a broken endpoint.
                empty_rounds += 1
                if 0 < args.max_empty_rounds <= empty_rounds:
                    tqdm.write(
                        f"[ERROR] 连续 {empty_rounds} 轮没有新增样本,"
                        f"请检查教师模型服务 {API_URL} 是否可用,已退出",
                        file=sys.stderr,
                    )
                    aborted = True
                    break
    except KeyboardInterrupt:
        # Records are appended as each job finishes, so nothing is buffered.
        print("[WARN] 用户中断,保存进度中…")
        pool.terminate()
    finally:
        pool.close()
        pool.join()
    if aborted:
        sys.exit(1)


if __name__ == "__main__":
    main()