"""Common utils for embeddings."""
import json
import re
import uuid
import warnings
from typing import Dict, List, Tuple

from tqdm import tqdm

from llama_index.bridge.pydantic import BaseModel
from llama_index.llms.utils import LLM
from llama_index.schema import MetadataMode, TextNode


class EmbeddingQAFinetuneDataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    queries: Dict[str, str]  # dict id -> query
    corpus: Dict[str, str]  # dict id -> string
    relevant_docs: Dict[str, List[str]]  # query id -> list of doc ids
    # Stored and serialized with the dataset; not read anywhere in this module.
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Return ``(query text, relevant doc ids)`` for every query.

        Raises:
            KeyError: if a query id has no entry in ``relevant_docs``.
        """
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save_json(self, path: str) -> None:
        """Serialize the dataset's fields to ``path`` as indented JSON."""
        # Explicit encoding so files round-trip regardless of platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.dict(), f, indent=4)

    @classmethod
    def from_json(cls, path: str) -> "EmbeddingQAFinetuneDataset":
        """Load a dataset previously written by :meth:`save_json`."""
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data)


# Must contain the ``{context_str}`` and ``{num_questions_per_chunk}``
# placeholders; both are filled via ``str.format`` in
# ``generate_qa_embedding_pairs``.
DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and not prior knowledge,
generate only questions based on the below query.

You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""

# Leading list numbering such as "1.", "2)" or "3 " that LLMs commonly put
# in front of each generated question.
_QUESTION_NUMBER_PREFIX_RE = re.compile(r"^\d+[\).\s]")


def _parse_questions(response_text: str, max_questions: int) -> List[str]:
    """Split raw LLM output into at most ``max_questions`` cleaned questions.

    Each non-empty line becomes one question, with any leading list number
    (``"1."``, ``"2)"``, ...) and surrounding whitespace removed.
    """
    questions = []
    for line in response_text.strip().splitlines():
        # Strip before matching so an indented "  2. ..." still loses its number.
        question = _QUESTION_NUMBER_PREFIX_RE.sub("", line.strip()).strip()
        if question:
            questions.append(question)
    # Extra lines (preambles, surplus questions) would otherwise end up in the
    # dataset as bogus queries.
    return questions[: max(max_questions, 0)]


# generate queries as a convenience function
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: LLM,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate synthetic (question, source node) pairs from ``nodes``.

    For every node, ``qa_generate_prompt_tmpl`` is formatted with the node's
    text (metadata excluded) and ``num_questions_per_chunk``, and sent to
    ``llm.complete``. Each line of the reply is treated as one question; at
    most ``num_questions_per_chunk`` are kept per node.

    Args:
        nodes: Nodes whose content serves as both prompt context and corpus.
        llm: LLM used to generate the questions.
        qa_generate_prompt_tmpl: Prompt template with ``{context_str}`` and
            ``{num_questions_per_chunk}`` placeholders.
        num_questions_per_chunk: Number of questions to request (and keep)
            per node.

    Returns:
        A dataset whose ``corpus`` maps node id -> node text, whose
        ``queries`` map a fresh UUID4 string -> question, and whose
        ``relevant_docs`` map each question id -> ``[source node id]``.

    Warns:
        UserWarning: when the LLM yields fewer questions than requested for
            a node.
    """
    # Nodes sharing a node_id collapse to a single corpus entry here.
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    queries: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for node_id, text in tqdm(node_dict.items()):
        query = qa_generate_prompt_tmpl.format(
            context_str=text, num_questions_per_chunk=num_questions_per_chunk
        )
        response = llm.complete(query)
        questions = _parse_questions(str(response), num_questions_per_chunk)
        if len(questions) < num_questions_per_chunk:
            warnings.warn(
                f"Fewer questions generated ({len(questions)}) than requested "
                f"({num_questions_per_chunk}) for node {node_id}."
            )

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # construct dataset
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )