369 lines
12 KiB
Python
369 lines
12 KiB
Python
import os
|
|
import json
|
|
import time
|
|
import gc
|
|
import torch
|
|
import argparse
|
|
import random
|
|
from hashlib import md5
|
|
import multiprocessing as mp
|
|
from typing import List, Optional
|
|
|
|
from constant import TaskType, Language, CodeLanguage, NUM_HARD_NEGATIVES
|
|
from corpus_generator import CorpusGenerator
|
|
from triplet_generator import TripletGenerator
|
|
from search import get_top1
|
|
|
|
|
|
def compute_md5(text: str):
    """Return the hexadecimal MD5 digest of ``text``.

    The string is encoded with ``str.encode()`` (UTF-8 by default) before
    hashing. The digest is used as a cheap fingerprint for de-duplication,
    not for anything security related.
    """
    digest = md5(text.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the triplet generation script.

    Returns:
        argparse.Namespace: The parsed options. ``--task_type``,
        ``--code_language``, ``--corpus_root`` and ``--save_dir`` are
        required; every other option has a default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task_type',
        type=str,
        required=True,
        help='The task type to generate data for',
        choices=[t.name for t in TaskType]
    )
    parser.add_argument(
        '--code_language',
        type=str,
        required=True,
        help='The code language to generate questions for.',
        choices=[c.name for c in CodeLanguage]
    )
    parser.add_argument(
        '--corpus_root',
        type=str,
        required=True,
        help='The root directory of the corpus data.'
    )
    parser.add_argument(
        '--save_dir',
        type=str,
        required=True,
        help='The path to save the generated data'
    )
    parser.add_argument(
        '--examples_dir',
        type=str,
        default=None,
        help='The path to the examples directory. If not None, the examples will be used for few-shot generation.'
    )
    parser.add_argument(
        '--num_examples',
        type=int,
        default=3,
        help='The number of examples to use for few-shot generation. Default: 3'
    )
    parser.add_argument(
        '--cache_dir',
        type=str,
        default=None,
        help='The cache directory'
    )
    parser.add_argument(
        '--language',
        type=str,
        default='en',
        help='The language to generate for. ISO 639-1 code. Default: en',
        choices=[lang.name for lang in Language]
    )
    parser.add_argument(
        '--tgt_code_language',
        type=str,
        default=None,
        help='The target code language to generate code translations for.',
        choices=[c.name for c in CodeLanguage]
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=-1,
        help='The number of examples to use for generation. Default: -1. Use all available examples.'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='Qwen2.5-72B-Instruct',
        help='The model to use for generation. Default: Qwen2.5-72B-Instruct'
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default='open-source',
        help='The type of model to use for generation. Default: open-source',
    )
    parser.add_argument(
        '--port',
        type=int,
        default=8000,
        help='The port for vllm.'
    )
    parser.add_argument(
        '--num_processes',
        type=int,
        default=1,
        help='The number of processes to use for generation. Default: 1'
    )
    parser.add_argument(
        '--doc_length',
        type=str,
        default='len_0_500',
        help='The corpus length used to load dataset. Default: len_0_500'
    )
    parser.add_argument(
        '--external_path',
        type=str,
        default='',
        # The value is split on whitespace by the caller, so several paths
        # can be given in one quoted argument.
        help="Whitespace-separated external path(s) passed to the corpus generator. Default: '' (none)"
    )
    parser.add_argument(
        '--sim_model_name',
        type=str,
        default=None,
        help='The name of the similarity model passed to the top-1 similar document search. '
             'Used by the code_modification_retrieval and code_comparison_retrieval tasks.'
    )
    parser.add_argument(
        '--max_corpus',
        type=int,
        default=500000,
        help='The max num of corpus to load.'
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Whether to overwrite the existing data.'
    )
    parser.add_argument(
        '--debug_mode',
        action='store_true',
        help='Whether to open debug mode.'
    )
    parser.add_argument(
        '--gen_hard_neg',
        action='store_true',
        help='Whether to generate hard negatives.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for generating triplets using the same positive. Default: None (no seed is set)'
    )
    return parser.parse_args()
|
|
|
|
|
|
def gen_triplets(
    model: str,
    model_type: str,
    port: int,
    positives: List[dict],
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: str,
    examples_pool: Optional[List[dict]] = None,
    num_examples: int = 3,
    tqdm_desc: str = "Generating triplets",
    thread_count: int = 1,
    gen_cache_dir: Optional[str] = None,
    debug_mode: bool = False,
    gen_hard_neg: bool = False,
):
    """Build a ``TripletGenerator`` and run it over ``positives``.

    ``model``, ``model_type``, ``port`` and ``gen_cache_dir`` configure the
    generator itself; every remaining argument is forwarded unchanged to
    ``TripletGenerator.run``. The number of negatives is always
    ``NUM_HARD_NEGATIVES``.

    Returns:
        Whatever ``TripletGenerator.run`` returns (the generated triplets).
    """
    generator = TripletGenerator(model, model_type, port, cache_dir=gen_cache_dir)
    run_options = dict(
        positives=positives,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language,
        examples_pool=examples_pool,
        num_examples=num_examples,
        tqdm_desc=tqdm_desc,
        thread_count=thread_count,
        debug_mode=debug_mode,
        gen_hard_neg=gen_hard_neg,
        num_negatives=NUM_HARD_NEGATIVES,
    )
    return generator.run(**run_options)
|
|
|
|
|
|
def get_save_path(
    save_dir: str,
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: Optional[str] = None
):
    """Return the JSONL path for one generation configuration.

    The file lives in ``<save_dir>/<language>/<task_type>/``; that directory
    is created if it does not exist yet. The file itself is not created.
    The file name carries the source code language and, when
    ``tgt_code_language`` is given, the target code language as well.
    """
    target_dir = os.path.join(save_dir, language, task_type)
    # Creating the directory here lets callers open the returned path directly.
    os.makedirs(target_dir, exist_ok=True)

    if tgt_code_language is None:
        return os.path.join(target_dir, f"{language}-{code_language}-triplets.jsonl")
    return os.path.join(
        target_dir,
        f"{language}-{code_language}-to-{tgt_code_language}-triplets.jsonl",
    )
|
|
|
|
|
|
def save_triplets(
    triplets: list,
    save_dir: str,
    task_type: str,
    language: str,
    code_language: str,
    tgt_code_language: Optional[str] = None
):
    """Merge newly generated triplets into the JSONL file for this configuration.

    Triplets already stored in the file are kept. A new triplet is skipped
    when the MD5 of its ``query`` or of its first positive (``pos[0]``)
    matches one already stored in the file. New triplets are compared only
    against the stored ones, not against each other.

    Args:
        triplets: Generated triplets; each must have a ``query`` string and a
            non-empty ``pos`` list of strings.
        save_dir: Root output directory.
        task_type: Task name, used in the output path.
        language: Natural language code, used in the output path.
        code_language: Source code language, used in the output file name.
        tgt_code_language: Optional target code language for the file name.
    """
    if not triplets:
        print(f"No triplets to save: {task_type} | {language} | {code_language} | {tgt_code_language}")
        return

    save_path = get_save_path(save_dir, task_type, language, code_language, tgt_code_language)
    query_md5s = set()
    pos_md5s = set()
    old_triplets = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            for line in f:
                # A stray blank line (e.g. a trailing newline added by hand)
                # must not abort the save after an expensive generation run.
                if not line.strip():
                    continue
                triplet = json.loads(line)
                old_triplets.append(triplet)
                query_md5s.add(compute_md5(triplet['query']))
                pos_md5s.add(compute_md5(triplet['pos'][0]))

    # Select the new triplets before the file is truncated, so a malformed
    # triplet raises while the existing file is still intact.
    new_triplets = [
        triplet for triplet in triplets
        if compute_md5(triplet['query']) not in query_md5s
        and compute_md5(triplet['pos'][0]) not in pos_md5s
    ]

    with open(save_path, 'w', encoding='utf-8') as f:
        for triplet in old_triplets:
            f.write(json.dumps(triplet, ensure_ascii=False) + '\n')
        for triplet in new_triplets:
            f.write(json.dumps(triplet, ensure_ascii=False) + '\n')
    print(f"Triplets saved to {save_path}")
|
|
|
|
|
|
def main(args):
    """Run the pipeline: sample positives, generate triplets, save them.

    Steps:
        1. Seed ``random`` when ``args.seed`` is given.
        2. Optionally load a pool of few-shot examples from ``args.examples_dir``.
        3. Load positives through ``CorpusGenerator.run``.
        4. For the code modification / comparison tasks, attach the top-1
           similar document to every positive under the ``similar`` key.
        5. Generate the triplets and merge them into the output file.

    Args:
        args: Parsed command-line options as returned by ``get_args()``.
    """
    seed = args.seed
    if seed is not None:
        print(f"------------------- Seed set to {seed} -------------------")
        random.seed(seed)

    model = args.model
    model_type = args.model_type
    port = args.port

    num_samples = args.num_samples

    task_type = args.task_type
    language = args.language
    code_language = args.code_language
    tgt_code_language = args.tgt_code_language

    corpus_root = args.corpus_root
    corpus_dir = os.path.join(corpus_root, code_language)
    # Both options accept several whitespace-separated values.
    doc_length = args.doc_length.split()
    external_path = args.external_path.split()

    save_dir = args.save_dir
    cache_dir = args.cache_dir
    # Use at most ~80% of the cores, but never fewer than one worker:
    # int(cpu_count * 0.8) is 0 on a single-core machine.
    num_processes = max(1, min(args.num_processes, int(mp.cpu_count() * 0.8)))
    debug_mode = args.debug_mode
    gen_hard_neg = args.gen_hard_neg

    # Called for its side effect: get_save_path creates the output directory.
    # NOTE: args.overwrite is parsed but currently has no effect here;
    # save_triplets always merges into an existing file.
    get_save_path(save_dir, task_type, language, code_language, tgt_code_language)

    corpus_generator = CorpusGenerator(cache_dir)

    examples_dir = args.examples_dir
    num_examples = args.num_examples
    examples_pool = None
    if examples_dir is not None:
        if task_type in ["code_translation_retrieval"]:
            examples_path = os.path.join(examples_dir, language, task_type,
                                         f"{code_language}-to-{tgt_code_language}_sample_examples.json")
        else:
            examples_path = os.path.join(examples_dir, language, task_type, f"{code_language}_sample_examples.json")
        try:
            with open(examples_path, 'r', encoding='utf-8') as f:
                examples_pool = json.load(f)
            # sample 30 examples for few-shot generation
            examples_pool = random.sample(examples_pool, min(30, len(examples_pool)))
        except Exception as exc:
            # Best effort: fall back to zero-shot generation, but report why
            # and let KeyboardInterrupt / SystemExit propagate.
            print(f'Error for loading examples from {examples_path}: {exc!r}')
            examples_pool = None

    positives, large_positives = corpus_generator.run(
        num_samples=num_samples,
        max_corpus=args.max_corpus,
        corpus_dir=corpus_dir,
        doc_length=doc_length,
        external_path=external_path,
        source_language=code_language
    )

    if task_type in ["code_modification_retrieval", "code_comparison_retrieval"]:
        top1_docs = get_top1([e['text'] for e in positives], args.sim_model_name, [e['text'] for e in large_positives])
        for i, similar_doc in enumerate(top1_docs):
            positives[i]['similar'] = similar_doc
        # Free memory held by the similarity search before generation starts.
        gc.collect()
        torch.cuda.empty_cache()

    print("=================== Generate training data ===================")
    print(f'Task Type: {task_type} | Language: {language} | Code Language: {code_language} | Target Code Language: {tgt_code_language}')
    start_time = time.time()
    triplets = gen_triplets(
        model=model,
        model_type=model_type,
        port=port,
        positives=positives,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language,
        examples_pool=examples_pool,
        num_examples=num_examples,
        thread_count=num_processes,
        gen_cache_dir=os.path.join(save_dir, language, task_type, "gen_cache_dir"),
        debug_mode=debug_mode,
        gen_hard_neg=gen_hard_neg,
    )
    save_triplets(
        triplets=triplets,
        save_dir=save_dir,
        task_type=task_type,
        language=language,
        code_language=code_language,
        tgt_code_language=tgt_code_language
    )
    end_time = time.time()
    print("=============================================================")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    print("=============================================================")
    print("DONE!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the pipeline.
    args = get_args()
    main(args)
|