190 lines
6.7 KiB
Python
190 lines
6.7 KiB
Python
"""Base embeddings file."""
|
|
|
|
import asyncio
|
|
from abc import abstractmethod
|
|
from typing import Coroutine, List, Tuple
|
|
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.embeddings.base import (
|
|
BaseEmbedding,
|
|
Embedding,
|
|
)
|
|
from llama_index.schema import ImageType
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
|
|
class MultiModalEmbedding(BaseEmbedding):
    """Base class for Multi Modal embeddings.

    Adds image-embedding entry points on top of ``BaseEmbedding``. Subclasses
    must implement ``_get_image_embedding`` and ``_aget_image_embedding``; the
    batch hooks ``_get_image_embeddings`` / ``_aget_image_embeddings`` have
    default implementations that fan out to the single-image hooks and may be
    overridden when the backend supports real batch requests.

    Every public method reports its work to ``self.callback_manager`` as a
    ``CBEventType.EMBEDDING`` event whose end payload carries the inputs
    (``EventPayload.CHUNKS``) and the resulting embeddings
    (``EventPayload.EMBEDDINGS``).
    """

    @abstractmethod
    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image synchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    @abstractmethod
    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image asynchronously.

        Subclasses should implement this method. Reference get_image_embedding's
        docstring for more information.
        """

    def get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image.

        Delegates to ``_get_image_embedding`` inside an EMBEDDING callback
        event.

        Args:
            img_file_path: The image to embed, passed through unchanged to
                ``_get_image_embedding``.

        Returns:
            The embedding produced by ``_get_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = self._get_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    async def aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """
        Embed the input image asynchronously.

        Async counterpart of ``get_image_embedding``: awaits
        ``_aget_image_embedding`` inside an EMBEDDING callback event.

        Args:
            img_file_path: The image to embed, passed through unchanged to
                ``_aget_image_embedding``.

        Returns:
            The embedding produced by ``_aget_image_embedding``.
        """
        with self.callback_manager.event(
            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
        ) as event:
            image_embedding = await self._aget_image_embedding(img_file_path)

            event.on_end(
                payload={
                    EventPayload.CHUNKS: [img_file_path],
                    EventPayload.EMBEDDINGS: [image_embedding],
                },
            )
        return image_embedding

    def _get_image_embeddings(self, img_file_paths: List[ImageType]) -> List[Embedding]:
        """
        Embed the input sequence of image synchronously.

        Subclasses can implement this method if batch queries are supported.
        """
        # Default implementation just loops over _get_image_embedding
        return [
            self._get_image_embedding(img_file_path) for img_file_path in img_file_paths
        ]

    async def _aget_image_embeddings(
        self, img_file_paths: List[ImageType]
    ) -> List[Embedding]:
        """
        Embed the input sequence of image asynchronously.

        Subclasses can implement this method if batch queries are supported.
        """
        # Default implementation runs _aget_image_embedding concurrently;
        # asyncio.gather returns results in the order of its arguments.
        return await asyncio.gather(
            *[
                self._aget_image_embedding(img_file_path)
                for img_file_path in img_file_paths
            ]
        )

    def get_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """
        Get a list of image embeddings, with batching.

        Inputs are grouped into batches of at most ``self.embed_batch_size``
        items; each batch is embedded via ``_get_image_embeddings`` inside its
        own EMBEDDING callback event.

        Args:
            img_file_paths: Images to embed.
            show_progress: Forwarded to ``get_tqdm_iterable`` to optionally
                wrap the input iteration in a progress bar.

        Returns:
            One embedding per input, in input order. Empty input yields an
            empty list.
        """
        cur_batch: List[ImageType] = []
        result_embeddings: List[Embedding] = []

        queue_with_progress = enumerate(
            get_tqdm_iterable(
                img_file_paths, show_progress, "Generating image embeddings"
            )
        )

        for idx, img_file_path in queue_with_progress:
            cur_batch.append(img_file_path)
            if (
                idx == len(img_file_paths) - 1
                or len(cur_batch) == self.embed_batch_size
            ):
                # flush
                with self.callback_manager.event(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                ) as event:
                    embeddings = self._get_image_embeddings(cur_batch)
                    result_embeddings.extend(embeddings)
                    event.on_end(
                        payload={
                            EventPayload.CHUNKS: cur_batch,
                            EventPayload.EMBEDDINGS: embeddings,
                        },
                    )
                # Rebind rather than clear(): the event payload above still
                # references the flushed list.
                cur_batch = []

        return result_embeddings

    async def aget_image_embedding_batch(
        self, img_file_paths: List[ImageType], show_progress: bool = False
    ) -> List[Embedding]:
        """
        Asynchronously get a list of image embeddings, with batching.

        Inputs are grouped into batches of at most ``self.embed_batch_size``
        items. An EMBEDDING callback event is started for every batch up
        front, all batches are then embedded concurrently via
        ``_aget_image_embeddings``, and each event is ended with its own
        batch's inputs and embeddings once all batches have finished.

        Args:
            img_file_paths: Images to embed.
            show_progress: If True and ``tqdm`` is importable, display a
                progress bar that advances once per completed batch. If
                ``tqdm`` is not installed the flag is silently ignored.

        Returns:
            One embedding per input, in input order (regardless of the order
            in which batches complete). Empty input yields an empty list.
        """
        cur_batch: List[ImageType] = []
        callback_payloads: List[Tuple[str, List[ImageType]]] = []
        embeddings_coroutines: List[Coroutine] = []
        for idx, img_file_path in enumerate(img_file_paths):
            cur_batch.append(img_file_path)
            if (
                idx == len(img_file_paths) - 1
                or len(cur_batch) == self.embed_batch_size
            ):
                # flush
                event_id = self.callback_manager.on_event_start(
                    CBEventType.EMBEDDING,
                    payload={EventPayload.SERIALIZED: self.to_dict()},
                )
                callback_payloads.append((event_id, cur_batch))
                embeddings_coroutines.append(self._aget_image_embeddings(cur_batch))
                cur_batch = []

        # Only the import is guarded: an ImportError raised by an embedding
        # call must propagate instead of being mistaken for "tqdm missing".
        use_progress_bar = show_progress
        if use_progress_bar:
            try:
                from tqdm.auto import tqdm
            except ImportError:
                use_progress_bar = False

        # Batches are always collected with asyncio.gather so that
        # nested_embeddings[i] belongs to callback_payloads[i]. Iterating
        # asyncio.as_completed instead would yield results in completion
        # order, scrambling both the returned embeddings and the event
        # payloads whenever a later batch finishes first.
        if use_progress_bar:
            with tqdm(
                total=len(embeddings_coroutines),
                desc="Generating image embeddings",
            ) as progress_bar:

                async def _tracked(batch_coroutine: Coroutine) -> List[Embedding]:
                    """Await one batch and advance the progress bar by one."""
                    batch_embeddings = await batch_coroutine
                    progress_bar.update(1)
                    return batch_embeddings

                nested_embeddings = await asyncio.gather(
                    *[_tracked(coroutine) for coroutine in embeddings_coroutines]
                )
        else:
            nested_embeddings = await asyncio.gather(*embeddings_coroutines)

        # flatten the results of asyncio.gather, which is a list of embeddings lists
        result_embeddings = [
            embedding for embeddings in nested_embeddings for embedding in embeddings
        ]

        for (event_id, image_batch), embeddings in zip(
            callback_payloads, nested_embeddings
        ):
            self.callback_manager.on_event_end(
                CBEventType.EMBEDDING,
                payload={
                    EventPayload.CHUNKS: image_batch,
                    EventPayload.EMBEDDINGS: embeddings,
                },
                event_id=event_id,
            )

        return result_embeddings
|