"""Sentence Transformer Finetuning Engine.""" import logging from typing import Any, List, Optional, Tuple, Type, cast from llama_index.embeddings.adapter import AdapterEmbeddingModel from llama_index.embeddings.base import BaseEmbedding from llama_index.finetuning.embeddings.common import EmbeddingQAFinetuneDataset from llama_index.finetuning.types import BaseEmbeddingFinetuneEngine from llama_index.utils import infer_torch_device logger = logging.getLogger(__name__) class EmbeddingAdapterFinetuneEngine(BaseEmbeddingFinetuneEngine): """Embedding adapter finetune engine. Args: dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on. embed_model (BaseEmbedding): Embedding model to finetune. batch_size (Optional[int]): Batch size. Defaults to 10. epochs (Optional[int]): Number of epochs. Defaults to 1. dim (Optional[int]): Dimension of embedding. Defaults to None. adapter_model (Optional[BaseAdapter]): Adapter model. Defaults to None, in which case a linear adapter is used. device (Optional[str]): Device to use. Defaults to None. model_output_path (str): Path to save model output. Defaults to "model_output". model_checkpoint_path (Optional[str]): Path to save model checkpoints. Defaults to None (don't save checkpoints). verbose (bool): Whether to show progress bar. Defaults to False. bias (bool): Whether to use bias. Defaults to False. """ def __init__( self, dataset: EmbeddingQAFinetuneDataset, embed_model: BaseEmbedding, batch_size: int = 10, epochs: int = 1, adapter_model: Optional[Any] = None, dim: Optional[int] = None, device: Optional[str] = None, model_output_path: str = "model_output", model_checkpoint_path: Optional[str] = None, checkpoint_save_steps: int = 100, verbose: bool = False, bias: bool = False, **train_kwargs: Any, ) -> None: """Init params.""" import torch from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer self.dataset = dataset self.embed_model = embed_model # HACK: get dimension by passing text through it if dim is None: test_embedding = self.embed_model.get_text_embedding("hello world") self.dim = len(test_embedding) else: self.dim = dim # load in data, run embedding model, define data loader self.batch_size = batch_size self.loader = self._get_data_loader(dataset) if device is None: device = infer_torch_device() logger.info(f"Use pytorch device: {device}") self._target_device = torch.device(device) if adapter_model is not None: self.model = cast(BaseAdapter, adapter_model) else: self.model = LinearLayer(self.dim, self.dim, bias=bias) self._model_output_path = model_output_path self._model_checkpoint_path = model_checkpoint_path self._checkpoint_save_steps = checkpoint_save_steps self._epochs = epochs self._warmup_steps = int(len(self.loader) * epochs * 0.1) self._train_kwargs = train_kwargs self._verbose = verbose @classmethod def from_model_path( cls, dataset: EmbeddingQAFinetuneDataset, embed_model: BaseEmbedding, model_path: str, model_cls: Optional[Type[Any]] = None, **kwargs: Any, ) -> "EmbeddingAdapterFinetuneEngine": """Load from model path. Args: dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on. embed_model (BaseEmbedding): Embedding model to finetune. model_path (str): Path to model. model_cls (Optional[Type[Any]]): Adapter model class. Defaults to None. 
**kwargs (Any): Additional kwargs (see __init__) """ from llama_index.embeddings.adapter_utils import LinearLayer model_cls = model_cls or LinearLayer model = model_cls.load(model_path) return cls(dataset, embed_model, adapter_model=model, **kwargs) def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]: """Smart batching collate.""" import torch from torch import Tensor query_embeddings: List[Tensor] = [] text_embeddings: List[Tensor] = [] for query, text in batch: query_embedding = self.embed_model.get_query_embedding(query) text_embedding = self.embed_model.get_text_embedding(text) query_embeddings.append(torch.tensor(query_embedding)) text_embeddings.append(torch.tensor(text_embedding)) query_embeddings_t = torch.stack(query_embeddings) text_embeddings_t = torch.stack(text_embeddings) return query_embeddings_t, text_embeddings_t def _get_data_loader(self, dataset: EmbeddingQAFinetuneDataset) -> Any: """Get data loader.""" from torch.utils.data import DataLoader examples: Any = [] for query_id, query in dataset.queries.items(): node_id = dataset.relevant_docs[query_id][0] text = dataset.corpus[node_id] examples.append((query, text)) data_loader = DataLoader(examples, batch_size=self.batch_size) data_loader.collate_fn = self.smart_batching_collate return data_loader def finetune(self, **train_kwargs: Any) -> None: """Finetune.""" from llama_index.finetuning.embeddings.adapter_utils import train_model # call model training train_model( self.model, self.loader, self._target_device, epochs=self._epochs, output_path=self._model_output_path, warmup_steps=self._warmup_steps, verbose=self._verbose, checkpoint_path=self._model_checkpoint_path, checkpoint_save_steps=self._checkpoint_save_steps, **self._train_kwargs, ) def get_finetuned_model(self, **model_kwargs: Any) -> BaseEmbedding: """Get finetuned model.""" return AdapterEmbeddingModel( self.embed_model, self._model_output_path, **model_kwargs )
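

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the module): the
# "qa_dataset.json" path and the "local:BAAI/bge-small-en" model string below
# are placeholder assumptions; substitute your own finetuning dataset and base
# embedding model. The `resolve_embed_model` import path may differ across
# llama_index versions.
#
#   from llama_index.embeddings import resolve_embed_model
#
#   dataset = EmbeddingQAFinetuneDataset.from_json("qa_dataset.json")
#   base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")
#
#   finetune_engine = EmbeddingAdapterFinetuneEngine(
#       dataset,
#       base_embed_model,
#       model_output_path="model_output",
#       epochs=4,
#       verbose=True,
#   )
#   finetune_engine.finetune()
#   finetuned_embed_model = finetune_engine.get_finetuned_model()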