faiss_rag_enterprise/llama_index/finetuning/rerankers/cohere_reranker.py

78 lines
2.7 KiB
Python

"""Cohere Reranker Finetuning Engine."""
import importlib.util
import os
from typing import Optional
from llama_index.finetuning.types import BaseCohereRerankerFinetuningEngine
from llama_index.indices.postprocessor import CohereRerank
class CohereRerankerFinetuneEngine(BaseCohereRerankerFinetuningEngine):
    """Cohere Reranker Finetune Engine.

    Wraps a ``cohere.Client`` to create a custom rerank model from JSONL
    training data and to expose the result as a ``CohereRerank``
    postprocessor.
    """

    def __init__(
        self,
        train_file_name: str = "train.jsonl",
        val_file_name: Optional[str] = None,
        model_name: str = "exp_finetune",
        model_type: str = "RERANK",
        base_model: str = "english",
        api_key: Optional[str] = None,
    ) -> None:
        """Init params.

        Args:
            train_file_name: Path of the JSONL training file handed to
                ``JsonlDataset`` when ``finetune`` runs.
            val_file_name: Optional path of a JSONL evaluation file. When
                falsy, only the training file is used.
            model_name: Name passed to ``create_custom_model``.
            model_type: Model type passed to ``create_custom_model``.
            base_model: Base model passed to ``create_custom_model``.
            api_key: Cohere API key. When falsy, the ``COHERE_API_KEY``
                environment variable is used instead.

        Raises:
            ImportError: If the ``cohere`` package cannot be found.
            ValueError: If no API key is passed and ``COHERE_API_KEY`` is
                not set in the environment.
        """
        # find_spec returns None when the 'cohere' module is not available.
        if importlib.util.find_spec("cohere") is None:
            raise ImportError(
                "Cannot import cohere. Please install the package using `pip install cohere`."
            )
        import cohere

        try:
            self.api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError:
            # os.environ raises KeyError (not IndexError) for a missing
            # variable; translate it into a clearer configuration error.
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable "
            ) from None

        self._model = cohere.Client(self.api_key, client_name="llama_index")
        self._train_file_name = train_file_name
        self._val_file_name = val_file_name
        self._model_name = model_name
        self._model_type = model_type
        self._base_model = base_model
        # Populated by finetune(); None means no finetuning has run yet.
        self._finetune_model = None

    def finetune(self) -> None:
        """Finetune model.

        Builds a ``JsonlDataset`` from the configured training file (and
        evaluation file, if one was given), asks the Cohere client to create
        a custom model from it, and stores the returned model on the
        instance for ``get_finetuned_model``.
        """
        from cohere.custom_model_dataset import JsonlDataset

        if self._val_file_name:
            # Uploading both train file and eval file
            dataset = JsonlDataset(
                train_file=self._train_file_name, eval_file=self._val_file_name
            )
        else:
            # Single train file upload
            dataset = JsonlDataset(train_file=self._train_file_name)

        self._finetune_model = self._model.create_custom_model(
            name=self._model_name,
            dataset=dataset,
            model_type=self._model_type,
            base_model=self._base_model,
        )

    def get_finetuned_model(self, top_n: int = 5) -> CohereRerank:
        """Return a ``CohereRerank`` configured with the finetuned model.

        Args:
            top_n: Passed through to ``CohereRerank`` as ``top_n``.

        Returns:
            A ``CohereRerank`` built with the finetuned model's id and this
            engine's API key.

        Raises:
            RuntimeError: If ``finetune`` has not been run yet.
        """
        if self._finetune_model is None:
            raise RuntimeError(
                "Finetuned model is not set yet. Please run the finetune method first."
            )
        return CohereRerank(
            model=self._finetune_model.id, top_n=top_n, api_key=self.api_key
        )