"""Embedding adapter model."""
|
|
|
|
import logging
|
|
from typing import Any, List, Optional, Type, cast
|
|
|
|
from llama_index.bridge.pydantic import PrivateAttr
|
|
from llama_index.callbacks import CallbackManager
|
|
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
|
|
from llama_index.core.embeddings.base import BaseEmbedding
|
|
from llama_index.utils import infer_torch_device
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AdapterEmbeddingModel(BaseEmbedding):
    """Adapter for any embedding model.

    This is a wrapper around any embedding model that adds an adapter layer \
    on top of it.
    This is useful for finetuning an embedding model on a downstream task.
    The embedding model can be any model - it does not need to expose gradients.

    Args:
        base_embed_model (BaseEmbedding): Base embedding model.
        adapter_path (str): Path to adapter.
        adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
            case a linear adapter is used.
        transform_query (bool): Whether to transform query embeddings. Defaults to True.
        device (Optional[str]): Device to use. Defaults to None.
        embed_batch_size (int): Batch size for embedding. Defaults to 10.
        callback_manager (Optional[CallbackManager]): Callback manager. \
            Defaults to None.

    """

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()

    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Init params."""
        # Import torch and the adapter utilities lazily, so the module can be
        # imported without torch installed.
        import torch

        from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        # Default to the linear adapter shipped with llama_index.
        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        # Restore the (finetuned) adapter weights and move them to the
        # target device.
        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )

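    # Note: a custom ``adapter_cls`` only needs the interface ``__init__``
    # actually calls: a ``load(path)`` classmethod, ``forward(tensor)``, and
    # ``to(device)``. See the sketch at the end of this module.
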
    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            # Move the base embedding onto the adapter's device, apply the
            # learned transform, and convert back to a plain list of floats.
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously get query embedding."""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            # Same transform as the sync path; only the base-model call is
            # awaited, since the adapter forward pass itself is synchronous.
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        # Text embeddings are returned unchanged; only query embeddings pass
        # through the adapter.
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)


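# A minimal sketch (commented out, illustrative only) of a custom adapter that
# could be passed as ``adapter_cls``. The two-layer architecture and the
# ``TwoLayerAdapter`` name are assumptions for the example, and ``BaseAdapter``
# may require further abstract methods (e.g. for saving/loading), so treat
# this as a shape, not a drop-in implementation:
#
#   import torch
#   from llama_index.embeddings.adapter_utils import BaseAdapter
#
#   class TwoLayerAdapter(BaseAdapter):
#       def __init__(self, dim: int) -> None:
#           super().__init__()
#           self.net = torch.nn.Sequential(
#               torch.nn.Linear(dim, dim),
#               torch.nn.ReLU(),
#               torch.nn.Linear(dim, dim),
#           )
#
#       def forward(self, embed: torch.Tensor) -> torch.Tensor:
#           return self.net(embed)
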
# Maintain for backwards compatibility
LinearAdapterEmbeddingModel = AdapterEmbeddingModel
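
# Usage sketch (illustrative, not part of the module). ``OpenAIEmbedding`` and
# the checkpoint path are assumptions for the example; any ``BaseEmbedding``
# subclass and any valid adapter checkpoint work the same way:
#
#   from llama_index.embeddings import OpenAIEmbedding
#
#   base_embed_model = OpenAIEmbedding()
#   embed_model = AdapterEmbeddingModel(
#       base_embed_model=base_embed_model,
#       adapter_path="model_output/model.pt",  # hypothetical checkpoint path
#   )
#   # Queries go through the adapter; document text embeddings do not.
#   query_embedding = embed_model.get_query_embedding("What is an adapter?")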