"""Common utils for embeddings."""
|
|
import json
|
|
import re
|
|
import uuid
|
|
from typing import Dict, List, Tuple
|
|
|
|
from tqdm import tqdm
|
|
|
|
from llama_index.bridge.pydantic import BaseModel
|
|
from llama_index.llms.utils import LLM
|
|
from llama_index.schema import MetadataMode, TextNode
|
|
|
|
|
|
class EmbeddingQAFinetuneDataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    queries: Dict[str, str]  # query id -> query text
    corpus: Dict[str, str]  # doc id -> document text
    relevant_docs: Dict[str, List[str]]  # query id -> ids of relevant docs
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Pair each query string with the ids of its relevant docs."""
        pairs: List[Tuple[str, List[str]]] = []
        for qid, query_text in self.queries.items():
            pairs.append((query_text, self.relevant_docs[qid]))
        return pairs

    def save_json(self, path: str) -> None:
        """Serialize the dataset to ``path`` as indented JSON."""
        with open(path, "w") as fp:
            json.dump(self.dict(), fp, indent=4)

    @classmethod
    def from_json(cls, path: str) -> "EmbeddingQAFinetuneDataset":
        """Build a dataset from a JSON file written by ``save_json``."""
        with open(path) as fp:
            payload = json.load(fp)
        return cls(**payload)
|
|
|
|
|
|
DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
|
|
Context information is below.
|
|
|
|
---------------------
|
|
{context_str}
|
|
---------------------
|
|
|
|
Given the context information and not prior knowledge.
|
|
generate only questions based on the below query.
|
|
|
|
You are a Teacher/ Professor. Your task is to setup \
|
|
{num_questions_per_chunk} questions for an upcoming \
|
|
quiz/examination. The questions should be diverse in nature \
|
|
across the document. Restrict the questions to the \
|
|
context information provided."
|
|
"""
|
|
|
|
|
|
# generate queries as a convenience function
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: LLM,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate a QA finetuning dataset by querying ``llm`` per node.

    Args:
        nodes (List[TextNode]): Nodes whose content becomes the corpus.
        llm (LLM): LLM used to generate questions for each chunk.
        qa_generate_prompt_tmpl (str): Prompt template with
            ``{context_str}`` and ``{num_questions_per_chunk}`` placeholders.
        num_questions_per_chunk (int): Number of questions to request
            per node.

    Returns:
        EmbeddingQAFinetuneDataset: ``queries`` maps a generated uuid to a
        question, ``corpus`` maps node id to node text, and
        ``relevant_docs`` maps each question id to the single node id the
        question was generated from.
    """
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    # Hoisted out of the loop: the leading-numbering pattern ("1.", "2)",
    # "3 ") is invariant, so compile it once instead of re-resolving it
    # for every response line of every node.
    numbering_re = re.compile(r"^\d+[\).\s]")

    queries = {}
    relevant_docs = {}
    for node_id, text in tqdm(node_dict.items()):
        query = qa_generate_prompt_tmpl.format(
            context_str=text, num_questions_per_chunk=num_questions_per_chunk
        )
        response = llm.complete(query)

        # One candidate question per response line: strip list numbering,
        # then drop any lines that end up empty.
        questions = [
            numbering_re.sub("", line).strip()
            for line in str(response).strip().split("\n")
        ]
        questions = [question for question in questions if question]

        # Each question is relevant to exactly the node it was generated from.
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # construct dataset
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )
|