faiss_rag_enterprise/llama_index/embeddings/google_palm.py

80 lines
2.6 KiB
Python

"""Google PaLM embeddings file."""
from typing import Any, List, Optional
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
class GooglePaLMEmbedding(BaseEmbedding):
    """Class for Google PaLM embeddings.

    Wraps the ``google.generativeai`` module and exposes its
    ``generate_embeddings`` call through the ``BaseEmbedding`` interface.

    Args:
        model_name (str): Model for embedding.
            Defaults to "models/embedding-gecko-001".
        api_key (Optional[str]): API key to access the model. Defaults to None.
        embed_batch_size (int): Passed through to ``BaseEmbedding``.
            Defaults to ``DEFAULT_EMBED_BATCH_SIZE``.
        callback_manager (Optional[CallbackManager]): Passed through to
            ``BaseEmbedding``. Defaults to None.
        **kwargs (Any): Extra keyword arguments passed through to
            ``BaseEmbedding``.

    Raises:
        ImportError: If the ``google-generativeai`` package is not installed.
    """

    # Holds the configured ``google.generativeai`` module (not a client object).
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "models/embedding-gecko-001",
        api_key: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        # Imported lazily so the dependency is only required when this
        # embedding class is actually instantiated.
        try:
            import google.generativeai as palm
        except ImportError as err:
            raise ImportError(
                "google-generativeai package not found, install with "
                "'pip install google-generativeai'"
            ) from err

        palm.configure(api_key=api_key)
        self._model = palm

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "PaLMEmbedding"

    def _embed(self, text: Any) -> Any:
        """Call ``generate_embeddings`` and return its "embedding" entry.

        Args:
            text: A single string or a list of strings to embed.

        Returns:
            The value stored under the "embedding" key of the response.
        """
        response = self._model.generate_embeddings(model=self.model_name, text=text)
        return response["embedding"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._embed(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding.

        ``self._model`` is the ``google.generativeai`` module, so this
        delegates to the synchronous implementation rather than calling a
        method on it. The call blocks the event loop while the request runs.
        """
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._embed(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding.

        Delegates to the synchronous ``_get_text_embedding`` defined on this
        class; the call blocks the event loop while the request runs.
        """
        return self._get_text_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return self._embed(texts)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Delegates to the synchronous ``_get_text_embeddings`` defined on this
        class; the call blocks the event loop while the request runs.
        """
        return self._get_text_embeddings(texts)