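"""Generate training triplets (query, positive, hard negatives) for code
retrieval tasks by prompting an LLM backend.

Task definitions and prompt templates live in ``constant``; the chat backend
lives in ``llm.LLM``. Generated results can be cached on disk, keyed by the
MD5 of the input text.
"""
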
import os
import json
import random
from tqdm import tqdm
from hashlib import md5
from warnings import warn
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor

from llm import LLM
from utils import clean_content
from constant import TaskType, Task, SPECIAL_TASK_STEPS, \
    get_task, get_generation_prompt, get_quality_control_prompt, \
    get_gen_hard_neg_prompt


def compute_md5(text: str):
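    """Return the MD5 hex digest of ``text``; used as a stable document/cache ID."""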
    return md5(text.encode()).hexdigest()


class TripletGenerator(LLM):
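    """Drive an LLM to generate (query, positive, negatives) triplets.

    Each ``_gen_for_*`` method implements the generation recipe for one task
    type; ``generate_triplets`` dispatches on ``task.task_type`` and applies
    LLM-based quality control where appropriate. Results are cached on disk
    when ``cache_dir`` is given.
    """
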
    def __init__(
        self,
        model: str = "Qwen2-5-Coder-32B-Instruct",
        model_type: str = "open-source",
        port: int = 8000,
        cache_dir: Optional[str] = None
    ):
        super().__init__(model, model_type, port)
        self.cache_dir = cache_dir
        if self.cache_dir is not None:
            os.makedirs(self.cache_dir, exist_ok=True)

    def _gen_for_code_modification_retrieval(
        self,
        task: Task,
        text: str,
        text_b: Optional[str] = None,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
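        """Two-step generation for the code modification retrieval task.

        Step 0 prompts for a diff between ``text`` and its modified version
        ``text_b``; step 1 turns that diff into a modification instruction.
        The query is the instruction followed by the original code, and the
        positive is the modified code ``text_b``.
        """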
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            text_b=text_b,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        diff = clean_content(response)
        gen_prompt = get_generation_prompt(
            task=task,
            text=diff,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        modification_instr = clean_content(response)

        query = f"{modification_instr}\n```\n{text}\n```"
        pos = text_b

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
        return result

    def _gen_for_code_comparison_retrieval(
        self,
        task: Task,
        text: str,
        text_b: Optional[str] = None,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
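        """Two-step generation for the code comparison retrieval task.

        Step 0 prompts for a question about the difference between the two
        code snippets; the query combines that question with both snippets.
        Step 1 then generates the answer, which serves as the positive.
        """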
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            text_b=text_b,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        diff_question = clean_content(response)
        query = f"{diff_question}\n\nInput Code:\n```\n{text}\n```\n\nOutput Code:\n```\n{text_b}\n```"
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
        return result

    def _gen_for_code_context_retrieval(
        self,
        task: Task,
        text: str,
        anchor_points: Optional[tuple] = (0.4, 0.7),
        **kwargs
    ):
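        """Split ``text`` into a former and a latter part (the split point is
        chosen by ``self.split_text`` based on ``anchor_points``); the former
        part becomes the query and the latter part the positive. No LLM call
        is needed for this task.
        """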
        former_part, latter_part = self.split_text(
            text,
            anchor_points=anchor_points
        )
        result = {
            "prompt": task.task_instruction,
            "query": former_part,
            "pos": [latter_part],
            "neg": []
        }
        return result

    @staticmethod
    def _arrange_query_and_pos(task: Task, input_text: str, response: str):
        """Arrange the query and positive example based on the task type.

        Args:
            task: the current Task.
            input_text: the original input text.
            response: the raw LLM response.

        Returns:
            A (query, pos) tuple. For text2code and hybrid tasks the
            generated text is the query and the input is the positive;
            otherwise the roles are reversed.
        """
        # TODO: support more task types, including some special task types.
        if task.main_task_type in ["text2code", "hybrid"]:
            query = clean_content(response)
            pos = input_text
        else:
            query = input_text
            pos = clean_content(response)
        return query, pos

    def _gen_for_normal_task(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
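        """Single-step generation for ordinary task types: prompt once, then
        arrange the input text and the response into (query, pos) according
        to the task's main type.
        """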
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            examples=examples
        )
        response = self.chat(gen_prompt, **kwargs)[0]

        # Arrange the query and positive example based on the task type.
        query, pos = self._arrange_query_and_pos(
            task=task,
            input_text=text,
            response=response
        )

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
            result["response"] = response
        return result

    def _gen_for_bug_desc_retrieval(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        **kwargs
    ):
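        """Two-step generation for the bug description retrieval task.

        Step 0 produces a buggy variant of ``text`` (``buggy_code``); step 1
        generates a description of that bug, which becomes the query. The
        original (correct) code is the positive.
        """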
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            examples=examples,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        if response is None:
            raise ValueError("Response is None.")
        buggy_code = response
        gen_prompt = get_generation_prompt(
            task=task,
            text=buggy_code,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        query = clean_content(response)
        pos = text

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
        return result

    def _gen_for_two_step_not_use_last(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        reverse_query_pos: bool = False,
        **kwargs
    ):
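        """Two-step generation where step 1 sees only step 0's output.

        Step 0 generates the query from ``text``; step 1 generates the
        positive from that query alone. Unlike ``_gen_for_two_step_use_last``,
        the original text is not carried forward into the query.
        """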
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        query = clean_content(response)
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)
        if reverse_query_pos:
            query, pos = pos, query

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
        return result

    def _gen_for_two_step_use_last(
        self,
        task: Task,
        text: str,
        examples: Optional[List[dict]] = None,
        debug_mode: bool = False,
        reverse_query_pos: bool = False,
        **kwargs
    ):
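        """Two-step generation where the original text is carried forward.

        Step 0 generates a query prefix, to which ``text`` itself is appended
        in a code fence; step 1 generates the positive from that combined
        query.
        """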
        gen_prompt = get_generation_prompt(
            task=task,
            text=text,
            idx=0
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        query = clean_content(response) + f"\n```\n{text}\n```"
        gen_prompt = get_generation_prompt(
            task=task,
            text=query,
            examples=examples,
            idx=1
        )
        response = self.chat(gen_prompt, **kwargs)[0]
        pos = clean_content(response)
        if reverse_query_pos:
            query, pos = pos, query

        result = {
            "prompt": task.task_instruction,
            "query": query,
            "pos": [pos],
            "neg": []
        }
        if debug_mode:
            result["generation_prompt"] = gen_prompt
        return result

    def generate_triplets(
        self,
        data: dict,
        task: Task,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        debug_mode: bool = False,
        **kwargs
    ):
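        """Generate triplets for one document, dispatching on the task type.

        Multi-step task types listed in ``SPECIAL_TASK_STEPS`` use their
        dedicated ``_gen_for_*`` recipe; code context retrieval skips quality
        control entirely; all other task types go through
        ``_gen_for_normal_task``. Generated pairs are then checked by an
        LLM quality-control prompt. Returns a (possibly empty) list of
        result dicts.
        """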
        kwargs["remove_thinking"] = not debug_mode

        result_list = []

        examples = None
        if examples_pool is not None:
            examples = random.sample(examples_pool, min(num_examples, len(examples_pool)))

        try:
            if task.task_type in SPECIAL_TASK_STEPS:
                text = data["text"]

                if task.task_type == TaskType.code_modification_retrieval:
                    text_b = data["similar"][0]

                    result = self._gen_for_code_modification_retrieval(
                        task=task,
                        text=text,
                        text_b=text_b,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type == TaskType.code_comparison_retrieval:
                    text_b = data["similar"][0]

                    result = self._gen_for_code_comparison_retrieval(
                        task=task,
                        text=text,
                        text_b=text_b,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type == TaskType.bug_desc_retrieval:
                    result = self._gen_for_bug_desc_retrieval(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode
                    )
                elif task.task_type in [
                    # cf - updated
                    TaskType.code_issue_discussion_retrieval,
                    TaskType.code_version_update_retrieval,
                    TaskType.code_bug_fix_example_retrieval,
                ]:
                    result = self._gen_for_two_step_not_use_last(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode,
                        reverse_query_pos=False
                    )
                elif task.task_type in [
                    # cf - updated
                    TaskType.code_refactoring_pattern_retrieval,
                    TaskType.code_style_guideline_example_retrieval,
                    TaskType.code_migration_retrieval,
                    # jl - updated
                    TaskType.code_optimization_hybrid_retrieval,
                    TaskType.code_best_practices_retrieval,
                    TaskType.security_vulnerability_fix_retrieval,
                ]:
                    result = self._gen_for_two_step_use_last(
                        task=task,
                        text=text,
                        examples=examples,
                        debug_mode=debug_mode,
                        reverse_query_pos=False
                    )
                else:
                    raise NotImplementedError(f"Task type {task.task_type} not implemented.")
            elif task.task_type == TaskType.code_context_retrieval:
                text = data["text"]

                result = self._gen_for_code_context_retrieval(
                    task=task,
                    text=text,
                    **kwargs
                )
                # NOTE: no need to do quality control for code context retrieval task
                result_list.append(result)
                return result_list
            else:
                text = data["text"]

                result = self._gen_for_normal_task(
                    task=task,
                    text=text,
                    examples=examples,
                    debug_mode=debug_mode,
                    **kwargs
                )

            # Quality control: the judge is expected to answer "1" to accept the pair.
            qc_prompt = get_quality_control_prompt(
                task=task,
                query=result["query"],
                pos=result["pos"][0]
            )
            response = self.chat(qc_prompt, **kwargs)[0]
            judge = clean_content(response)
            if "1" in judge:
                if debug_mode:
                    result["judge"] = judge
                    result["judge_response"] = response
                result_list.append(result)
            elif debug_mode:
                # Keep rejected pairs only in debug mode so they can be inspected;
                # outside debug mode they are dropped.
                result["judge"] = judge
                result["judge_response"] = response
                result_list.append(result)
        except Exception as e:
            warn(f"Error: {e}")

        return result_list

    def gen_hard_negatives(self, result: dict, task: Task, num_negatives: int = 7, **kwargs):
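        """Ask the LLM for up to ``num_negatives`` hard negatives for the
        (query, pos) pair in ``result``, deduplicate them, and return the
        updated ``result``.
        """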
        gen_hard_neg_prompt = get_gen_hard_neg_prompt(
            task=task,
            query=result["query"],
            pos=result["pos"][0]
        )
        response_list = self.chat(gen_hard_neg_prompt, n=num_negatives, **kwargs)
        for response in response_list:
            if response is None:
                continue
            hard_neg = clean_content(response)
            result["neg"].append(hard_neg)
        result["neg"] = list(set(result["neg"]))
        return result

    def run_single(
        self,
        data: dict,
        task: Task,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        **kwargs
    ):
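        """Process one document: return cached results when available
        (optionally back-filling missing hard negatives), otherwise generate
        a fresh triplet, optionally add hard negatives, and write the cache.
        """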
        result_list = []

        docid = compute_md5(data["text"])
        if self.cache_dir is not None:
            gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
            if os.path.exists(gen_data_cache_path):
                with open(gen_data_cache_path, "r", encoding="utf-8") as f:
                    result_list = json.load(f)

        if len(result_list) > 0:
            if gen_hard_neg:
                for i in range(len(result_list)):
                    if len(result_list[i]["neg"]) == 0:
                        result_list[i] = self.gen_hard_negatives(
                            result=result_list[i],
                            task=task,
                            num_negatives=num_negatives,
                            **kwargs
                        )
                # overwrite the cache file
                with open(gen_data_cache_path, "w", encoding="utf-8") as f:
                    json.dump(result_list, f, indent=4, ensure_ascii=False)
            return result_list

        triplets = self.generate_triplets(
            data,
            task=task,
            examples_pool=examples_pool,
            num_examples=num_examples,
            debug_mode=debug_mode,
            **kwargs
        )
        if len(triplets) == 0:
            return []

        result = triplets[0]
        if debug_mode:
            result["docid"] = docid

        if gen_hard_neg:
            result = self.gen_hard_negatives(
                result,
                task=task,
                num_negatives=num_negatives,
                **kwargs
            )

        result_list.append(result)

        if self.cache_dir is not None:
            gen_data_cache_path = os.path.join(self.cache_dir, f"{docid}.json")
            with open(gen_data_cache_path, "w", encoding="utf-8") as f:
                json.dump(result_list, f, indent=4, ensure_ascii=False)

        return result_list

    def run(
        self,
        positives: List[dict],
        task_type: str,
        language: str = "en",
        code_language: str = "python",
        tgt_code_language: Optional[str] = None,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        tqdm_desc: str = "Generating triplets",
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        thread_count: int = 1,
        **kwargs
    ):
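        """Generate triplets for a list of positive documents in parallel
        using a thread pool, and return the flattened result list.
        """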
        task = get_task(
            task_type=task_type,
            language=language,
            code_language=code_language,
            tgt_code_language=tgt_code_language
        )

        result_list = []

        def process_positive(positive):
            return self.run_single(
                data=positive,
                task=task,
                examples_pool=examples_pool,
                num_examples=num_examples,
                debug_mode=debug_mode,
                gen_hard_neg=gen_hard_neg,
                num_negatives=num_negatives,
                **kwargs
            )

        # Use thread pool for parallel processing with tqdm progress bar.
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            results = list(tqdm(executor.map(
                process_positive,
                positives
            ), total=len(positives), desc=tqdm_desc))

        # Collect results into result_list.
        for res in results:
            if isinstance(res, list):
                result_list.extend(res)
            else:
                result_list.append(res)

        return result_list

    def run_for_gen_neg(
        self,
        pairs: List[dict],
        task_type: str,
        language: str = "en",
        code_language: str = "python",
        tgt_code_language: Optional[str] = None,
        examples_pool: Optional[List[dict]] = None,
        num_examples: int = 3,
        tqdm_desc: str = "Generating triplets",
        debug_mode: bool = False,
        gen_hard_neg: bool = False,
        num_negatives: int = 7,
        thread_count: int = 1,
        **kwargs
    ):
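        """Generate hard negatives for existing (query, pos) pairs in
        parallel. The ``examples_pool``, ``num_examples``, ``debug_mode``
        and ``gen_hard_neg`` parameters are accepted for signature parity
        with ``run`` but are not used here.
        """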
        task = get_task(
            task_type=task_type,
            language=language,
            code_language=code_language,
            tgt_code_language=tgt_code_language
        )

        result_list = []

        def gen_single_negative(pair):
            result = self.gen_hard_negatives(
                pair,
                task=task,
                num_negatives=num_negatives,
                **kwargs
            )
            return [result]

        # Use thread pool for parallel processing with tqdm progress bar.
        with ThreadPoolExecutor(max_workers=thread_count) as executor:
            results = list(tqdm(executor.map(
                gen_single_negative,
                pairs
            ), total=len(pairs), desc=tqdm_desc))

        # Collect results into result_list.
        for res in results:
            if isinstance(res, list):
                result_list.extend(res)
            else:
                result_list.append(res)

        return result_list
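

# A minimal usage sketch. The file names, the corpus layout (a JSONL file
# whose records carry a "text" field, as expected by run_single), and the
# task type name "web_code_retrieval" are assumptions for illustration,
# not part of the module above.
if __name__ == "__main__":
    generator = TripletGenerator(
        model="Qwen2-5-Coder-32B-Instruct",
        model_type="open-source",
        port=8000,
        cache_dir="./cache"  # results are cached per document MD5
    )
    with open("corpus.jsonl", "r", encoding="utf-8") as f:
        positives = [json.loads(line) for line in f]
    triplets = generator.run(
        positives,
        task_type="web_code_retrieval",  # hypothetical task type name
        code_language="python",
        gen_hard_neg=True,
        thread_count=8
    )
    with open("triplets.json", "w", encoding="utf-8") as f:
        json.dump(triplets, f, indent=4, ensure_ascii=False)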