# faiss_rag_enterprise/llama_index/embeddings/adapter.py

"""Embedding adapter model."""
import logging
from typing import Any, List, Optional, Type, cast
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.embeddings.base import BaseEmbedding
from llama_index.utils import infer_torch_device
logger = logging.getLogger(__name__)
class AdapterEmbeddingModel(BaseEmbedding):
    """Adapter for any embedding model.

    This is a wrapper around any embedding model that adds an adapter layer \
        on top of it.
    This is useful for finetuning an embedding model on a downstream task.
    The embedding model can be any model - it does not need to expose gradients.

    Args:
        base_embed_model (BaseEmbedding): Base embedding model.
        adapter_path (str): Path to the saved adapter.
        adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
            case a linear adapter is used.
        transform_query (bool): Whether to transform query embeddings. Defaults to True.
        device (Optional[str]): Device to use. Defaults to None.
        embed_batch_size (int): Batch size for embedding. Defaults to 10.
        callback_manager (Optional[CallbackManager]): Callback manager. \
            Defaults to None.
    """

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()
    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Init params."""
        # torch and the adapter utilities are imported lazily so that importing
        # this module does not require torch to be installed.
        import torch

        from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        # Default to a linear adapter unless a custom adapter class is provided.
        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        # Load the trained adapter weights and move them to the target device.
        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )
    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            # Apply the learned adapter transform to the query embedding only;
            # text (document) embeddings are passed through unchanged below.
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding (async)."""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)


# Maintained for backwards compatibility.
LinearAdapterEmbeddingModel = AdapterEmbeddingModel
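

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # wrap a base embedding model with a trained adapter and embed a query.
    # Assumptions: `resolve_embed_model` is exported from llama_index.embeddings
    # in this llama_index version, the model string is a placeholder, and
    # "adapter_output" is a placeholder path to a directory of saved adapter
    # weights (e.g. produced by adapter finetuning).
    from llama_index.embeddings import resolve_embed_model

    base_embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")
    embed_model = AdapterEmbeddingModel(
        base_embed_model=base_embed_model,
        adapter_path="adapter_output",  # placeholder: saved adapter directory
        transform_query=True,
    )

    # Query embeddings pass through the adapter; text embeddings do not.
    query_embedding = embed_model.get_query_embedding("What does the adapter change?")
    print(f"query embedding dimension: {len(query_embedding)}")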