# faiss_rag_enterprise/llama_index/query_engine/multi_modal.py
#
# 233 lines
# 8.8 KiB
# Python

from typing import Any, Dict, List, Optional, Sequence, Tuple
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.response.schema import RESPONSE_TYPE, Response
from llama_index.indices.multi_modal import MultiModalVectorIndexRetriever
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.indices.query.schema import QueryBundle, QueryType
from llama_index.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
from llama_index.prompts.mixin import PromptMixinType
from llama_index.schema import ImageNode, NodeWithScore
def _get_image_and_text_nodes(
    nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
    """Split scored nodes into image entries and all remaining entries.

    Args:
        nodes: Scored nodes, e.g. as returned by a retriever.

    Returns:
        A pair ``(image_nodes, text_nodes)``. The first list holds every entry
        whose wrapped ``.node`` is an ``ImageNode``; the second holds every
        other entry. Both lists keep the relative order of ``nodes``.
    """
    images: List[NodeWithScore] = []
    others: List[NodeWithScore] = []
    for scored in nodes:
        # Single pass: choose the destination list, then append once.
        target = images if isinstance(scored.node, ImageNode) else others
        target.append(scored)
    return images, others
class SimpleMultiModalQueryEngine(BaseQueryEngine):
    """Simple Multi Modal Retriever query engine.

    Assumes that the retrieved text context fits within the context window of
    the LLM, along with the images.

    Args:
        retriever (MultiModalVectorIndexRetriever): A retriever object.
        multi_modal_llm (Optional[MultiModalLLM]): Multi-modal LLM used to
            generate answers. Defaults to ``OpenAIMultiModal`` with
            ``model="gpt-4-vision-preview"`` and ``max_new_tokens=1000``.
        text_qa_template (Optional[BasePromptTemplate]): Prompt for text
            queries, formatted with ``context_str`` and ``query_str``.
            Defaults to ``DEFAULT_TEXT_QA_PROMPT``.
        image_qa_template (Optional[BasePromptTemplate]): Prompt for
            ``image_query``, formatted with ``context_str`` (always empty) and
            ``query_str``. Defaults to ``DEFAULT_TEXT_QA_PROMPT``.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node
            postprocessors, applied in order to retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager.
    """

    def __init__(
        self,
        retriever: MultiModalVectorIndexRetriever,
        multi_modal_llm: Optional[MultiModalLLM] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        image_qa_template: Optional[BasePromptTemplate] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        # Extra keyword arguments are accepted and ignored.
        self._retriever = retriever
        self._multi_modal_llm = multi_modal_llm or OpenAIMultiModal(
            model="gpt-4-vision-preview", max_new_tokens=1000
        )
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT
        self._image_qa_template = image_qa_template or DEFAULT_TEXT_QA_PROMPT
        self._node_postprocessors = node_postprocessors or []

        callback_manager = callback_manager or CallbackManager([])
        # Give every postprocessor the same callback manager as this engine.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts (both the text and the image QA templates)."""
        return {
            "text_qa_template": self._text_qa_template,
            "image_qa_template": self._image_qa_template,
        }

    def _update_prompts(self, prompts_dict: Dict[str, Any]) -> None:
        """Update prompts.

        Only the keys reported by ``_get_prompts`` are recognised; any other
        keys are left for the base class to handle.
        """
        if "text_qa_template" in prompts_dict:
            self._text_qa_template = prompts_dict["text_qa_template"]
        if "image_qa_template" in prompts_dict:
            self._image_qa_template = prompts_dict["image_qa_template"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules (none for this engine)."""
        return {}

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run each configured postprocessor over ``nodes``, in order."""
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``retrieve``."""
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def _format_text_prompt(
        self, query_bundle: QueryBundle, text_nodes: List[NodeWithScore]
    ) -> str:
        """Fill the text QA template with the text nodes' content and the query.

        The text nodes' contents are joined with blank lines into ``context_str``.
        """
        context_str = "\n\n".join(r.get_content() for r in text_nodes)
        return self._text_qa_template.format(
            context_str=context_str, query_str=query_bundle.query_str
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Answer the query from retrieved text and image nodes.

        Text nodes become the prompt's ``context_str``. Image nodes are sent to
        the multi-modal LLM as image documents.

        Args:
            query_bundle: The query to answer.
            nodes: Retrieved (and postprocessed) nodes.
            additional_source_nodes: Extra nodes appended to the response's
                ``source_nodes``. They are not sent to the LLM.

        Returns:
            A ``Response`` whose metadata holds the ``text_nodes`` and
            ``image_nodes`` partitions of ``nodes``.
        """
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        llm_response = self._multi_modal_llm.complete(
            prompt=self._format_text_prompt(query_bundle, text_nodes),
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            # Previously additional_source_nodes was accepted but dropped.
            source_nodes=list(nodes) + list(additional_source_nodes or []),
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    def _get_response_with_images(
        self,
        prompt_str: str,
        image_nodes: List[NodeWithScore],
    ) -> RESPONSE_TYPE:
        """Answer ``prompt_str`` using only the given scored image nodes.

        Args:
            prompt_str: The user's question about the images.
            image_nodes: Scored nodes whose ``.node`` is an image node.

        Returns:
            A ``Response`` whose sources and ``image_nodes`` metadata are
            ``image_nodes``.
        """
        # The default image template is DEFAULT_TEXT_QA_PROMPT. That template
        # is also formatted with ``context_str`` in ``synthesize``. Pass an
        # empty context so formatting does not fail on the missing variable;
        # here the images themselves are the context.
        fmt_prompt = self._image_qa_template.format(
            context_str="",
            query_str=prompt_str,
        )
        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=image_nodes,
            metadata={"image_nodes": image_nodes},
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of ``synthesize`` (same arguments and result)."""
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        llm_response = await self._multi_modal_llm.acomplete(
            prompt=self._format_text_prompt(query_bundle, text_nodes),
            # Pass the unwrapped nodes, not the NodeWithScore wrappers, exactly
            # as the sync path does.
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=list(nodes) + list(additional_source_nodes or []),
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query: retrieve, postprocess, then synthesize."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            response = self.synthesize(
                query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def image_query(self, image_path: QueryType, prompt_str: str) -> RESPONSE_TYPE:
        """Answer ``prompt_str`` over images found by image-to-image retrieval.

        Args:
            image_path: Query image, passed to the retriever's
                ``image_to_image_retrieve``.
            prompt_str: The question to ask about the retrieved images.

        Returns:
            The LLM's response. Only image nodes are used; any non-image nodes
            from retrieval are discarded.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: str(image_path)}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: str(image_path)},
            ) as retrieve_event:
                nodes = self._retriever.image_to_image_retrieve(image_path)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            image_nodes, _ = _get_image_and_text_nodes(nodes)
            response = self._get_response_with_images(
                prompt_str=prompt_str,
                image_nodes=image_nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

            response = await self.asynthesize(
                query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> MultiModalVectorIndexRetriever:
        """Get the retriever object."""
        return self._retriever