# faiss_rag_enterprise/llama_index/finetuning/rerankers/dataset_gen.py

import random
from typing import Any, List, Optional, Tuple

from llama_index.bridge.pydantic import BaseModel
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from llama_index.indices.query.embedding_utils import get_top_k_embeddings


class CohereRerankerFinetuneDataset(BaseModel):
    """Class for keeping track of a CohereAI Reranker finetuning training/validation dataset."""

    query: str
    relevant_passages: List[str]
    hard_negatives: Any

    def to_jsonl(self) -> str:
        """Convert the BaseModel instance to a JSONL string."""
        return self.json() + "\n"
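
# For reference, a single serialized row in the output JSONL looks like this
# (illustrative values only):
# {"query": "...", "relevant_passages": ["..."], "hard_negatives": ["...", "..."]}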


def generate_embeddings(embed_model: Any, text: str) -> List[float]:
    # Generate an embedding for a single piece of text
    return embed_model.get_text_embedding(text)


def generate_hard_negatives(
    queries: List[str],
    relevant_contexts: List[str],
    embed_model: Optional[Any],
    num_negatives: int = 5,
    method: str = "random",
) -> Any:
    """Generate hard negatives for each query, either sampled at random
    or chosen by similarity between query and context embeddings."""
    hard_negatives = []

    if method == "cosine_similarity":
        query_embeddings = [
            generate_embeddings(embed_model, query) for query in queries
        ]
        relevant_contexts_embeddings = [
            generate_embeddings(embed_model, context) for context in relevant_contexts
        ]

    for query_index, _ in enumerate(queries):
        if method == "random":
            # Exclude the correct context for this query
            potential_negatives = (
                relevant_contexts[:query_index] + relevant_contexts[query_index + 1 :]
            )
            # Randomly select hard negatives from the remaining contexts
            hard_negatives.append(
                random.sample(
                    potential_negatives, min(num_negatives, len(potential_negatives))
                )
            )
        elif method == "cosine_similarity":
            query_embedding = query_embeddings[query_index]
            # Use get_top_k_embeddings to rank contexts by similarity to the
            # query, then keep the closest ones that are not the correct context
            _, relevant_contexts_indices = get_top_k_embeddings(
                query_embedding,
                relevant_contexts_embeddings,
            )
            # Filter out the correct context to only include hard negatives
            hard_negative_indices = [
                idx for idx in relevant_contexts_indices if idx != query_index
            ][:num_negatives]
            # Map indices to actual contexts to get the hard negatives
            hard_negatives_for_query = [
                relevant_contexts[idx] for idx in hard_negative_indices
            ]
            hard_negatives.append(hard_negatives_for_query)
    return hard_negatives
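
# Illustrative example (hypothetical passages): with three queries,
# num_negatives=1, and method="random", the return value holds one list of
# negatives per query, each drawn from the other queries' positive contexts,
# e.g. [["passage for q2"], ["passage for q3"], ["passage for q1"]].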


def get_query_context_lists(
    query_context_pairs: EmbeddingQAFinetuneDataset,
) -> Tuple[List[str], List[str]]:
    queries = []
    relevant_contexts = []

    # 'query_context_pairs' is an object with 'queries', 'corpus', and
    # 'relevant_docs' attributes
    for query_id, query in query_context_pairs.queries.items():
        # Get the first relevant document ID for the current query
        relevant_doc_id = query_context_pairs.relevant_docs[query_id][0]
        # Get the relevant context using the relevant document ID
        relevant_context = query_context_pairs.corpus[relevant_doc_id]
        # Append the query and the relevant context to their respective lists
        queries.append(query)
        relevant_contexts.append(relevant_context)
    return queries, relevant_contexts
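
# Worked example (hypothetical data): given
#   queries={"q1": "What is FAISS?"},
#   relevant_docs={"q1": ["d1"]},
#   corpus={"d1": "FAISS is a library for efficient similarity search."},
# this returns (["What is FAISS?"],
#               ["FAISS is a library for efficient similarity search."]).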


def generate_cohere_reranker_finetuning_dataset(
    query_context_pairs: EmbeddingQAFinetuneDataset,
    num_negatives: int = 0,
    top_k_dissimilar: int = 100,
    hard_negatives_gen_method: str = "random",
    finetune_dataset_file_name: str = "train.jsonl",
    embed_model: Optional[Any] = None,
) -> None:
    """Write a Cohere Reranker finetuning dataset to `finetune_dataset_file_name`
    in JSONL format, one query/relevant-passages/hard-negatives row per query.

    Note: `top_k_dissimilar` is accepted but currently unused in this function.
    """
    queries, relevant_contexts = get_query_context_lists(query_context_pairs)

    if num_negatives:
        hard_negatives = generate_hard_negatives(
            queries,
            relevant_contexts,
            embed_model,
            num_negatives,
            hard_negatives_gen_method,
        )
    else:
        hard_negatives = [[] for _ in queries]

    # Open the file in write mode
    with open(finetune_dataset_file_name, "w") as outfile:
        # Iterate over the lists simultaneously using zip
        for query, context, hard_negative in zip(
            queries, relevant_contexts, hard_negatives
        ):
            # Instantiate a CohereRerankerFinetuneDataset object for the current entry
            entry = CohereRerankerFinetuneDataset(
                query=query, relevant_passages=[context], hard_negatives=hard_negative
            )
            # Write the JSONL string to the file
            outfile.write(entry.to_jsonl())
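

if __name__ == "__main__":
    # Minimal usage sketch; the file names here are assumptions, not part of
    # the module. Assumes a QA dataset previously saved to "qa_pairs.json",
    # e.g. via llama_index.finetuning.generate_qa_embedding_pairs(...).save_json(...).
    pairs = EmbeddingQAFinetuneDataset.from_json("qa_pairs.json")
    generate_cohere_reranker_finetuning_dataset(
        pairs,
        num_negatives=5,
        hard_negatives_gen_method="random",  # "cosine_similarity" also requires embed_model
        finetune_dataset_file_name="train.jsonl",
    )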