from typing import Any, Dict, List, Optional, Sequence, Tuple

from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.response.schema import RESPONSE_TYPE, Response
from llama_index.indices.multi_modal import MultiModalVectorIndexRetriever
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.indices.query.schema import QueryBundle, QueryType
from llama_index.multi_modal_llms.base import MultiModalLLM
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
from llama_index.prompts.mixin import PromptMixinType
from llama_index.schema import ImageNode, NodeWithScore


def _get_image_and_text_nodes(
    nodes: List[NodeWithScore],
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
    image_nodes = []
    text_nodes = []
    for res_node in nodes:
        if isinstance(res_node.node, ImageNode):
            image_nodes.append(res_node)
        else:
            text_nodes.append(res_node)
    return image_nodes, text_nodes


class SimpleMultiModalQueryEngine(BaseQueryEngine):
    """Simple Multi Modal Retriever query engine.

    Assumes that retrieved text context fits within context window of LLM,
    along with images.

    Args:
        retriever (MultiModalVectorIndexRetriever): A retriever object.
        multi_modal_llm (Optional[MultiModalLLM]): MultiModalLLM Models.
        text_qa_template (Optional[BasePromptTemplate]): Text QA Prompt Template.
        image_qa_template (Optional[BasePromptTemplate]): Image QA Prompt Template.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Node Postprocessors.
        callback_manager (Optional[CallbackManager]): A callback manager.
""" def __init__( self, retriever: MultiModalVectorIndexRetriever, multi_modal_llm: Optional[MultiModalLLM] = None, text_qa_template: Optional[BasePromptTemplate] = None, image_qa_template: Optional[BasePromptTemplate] = None, node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ) -> None: self._retriever = retriever self._multi_modal_llm = multi_modal_llm or OpenAIMultiModal( model="gpt-4-vision-preview", max_new_tokens=1000 ) self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT self._image_qa_template = image_qa_template or DEFAULT_TEXT_QA_PROMPT self._node_postprocessors = node_postprocessors or [] callback_manager = callback_manager or CallbackManager([]) for node_postprocessor in self._node_postprocessors: node_postprocessor.callback_manager = callback_manager super().__init__(callback_manager) def _get_prompts(self) -> Dict[str, Any]: """Get prompts.""" return {"text_qa_template": self._text_qa_template} def _get_prompt_modules(self) -> PromptMixinType: """Get prompt sub-modules.""" return {} def _apply_node_postprocessors( self, nodes: List[NodeWithScore], query_bundle: QueryBundle ) -> List[NodeWithScore]: for node_postprocessor in self._node_postprocessors: nodes = node_postprocessor.postprocess_nodes( nodes, query_bundle=query_bundle ) return nodes def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: nodes = self._retriever.retrieve(query_bundle) return self._apply_node_postprocessors(nodes, query_bundle=query_bundle) async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: nodes = await self._retriever.aretrieve(query_bundle) return self._apply_node_postprocessors(nodes, query_bundle=query_bundle) def synthesize( self, query_bundle: QueryBundle, nodes: List[NodeWithScore], additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, ) -> RESPONSE_TYPE: image_nodes, text_nodes = _get_image_and_text_nodes(nodes) context_str = "\n\n".join([r.get_content() for r in text_nodes]) fmt_prompt = self._text_qa_template.format( context_str=context_str, query_str=query_bundle.query_str ) llm_response = self._multi_modal_llm.complete( prompt=fmt_prompt, image_documents=[image_node.node for image_node in image_nodes], ) return Response( response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes}, ) def _get_response_with_images( self, prompt_str: str, image_nodes: List[ImageNode], ) -> RESPONSE_TYPE: fmt_prompt = self._image_qa_template.format( query_str=prompt_str, ) llm_response = self._multi_modal_llm.complete( prompt=fmt_prompt, image_documents=[image_node.node for image_node in image_nodes], ) return Response( response=str(llm_response), source_nodes=image_nodes, metadata={"image_nodes": image_nodes}, ) async def asynthesize( self, query_bundle: QueryBundle, nodes: List[NodeWithScore], additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, ) -> RESPONSE_TYPE: image_nodes, text_nodes = _get_image_and_text_nodes(nodes) context_str = "\n\n".join([r.get_content() for r in text_nodes]) fmt_prompt = self._text_qa_template.format( context_str=context_str, query_str=query_bundle.query_str ) llm_response = await self._multi_modal_llm.acomplete( prompt=fmt_prompt, image_documents=image_nodes, ) return Response( response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes}, ) def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: """Answer a 
query.""" with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: with self.callback_manager.event( CBEventType.RETRIEVE, payload={EventPayload.QUERY_STR: query_bundle.query_str}, ) as retrieve_event: nodes = self.retrieve(query_bundle) retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) response = self.synthesize( query_bundle, nodes=nodes, ) query_event.on_end(payload={EventPayload.RESPONSE: response}) return response def image_query(self, image_path: QueryType, prompt_str: str) -> RESPONSE_TYPE: """Answer a image query.""" with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: str(image_path)} ) as query_event: with self.callback_manager.event( CBEventType.RETRIEVE, payload={EventPayload.QUERY_STR: str(image_path)}, ) as retrieve_event: nodes = self._retriever.image_to_image_retrieve(image_path) retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) image_nodes, _ = _get_image_and_text_nodes(nodes) response = self._get_response_with_images( prompt_str=prompt_str, image_nodes=image_nodes, ) query_event.on_end(payload={EventPayload.RESPONSE: response}) return response async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: """Answer a query.""" with self.callback_manager.event( CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} ) as query_event: with self.callback_manager.event( CBEventType.RETRIEVE, payload={EventPayload.QUERY_STR: query_bundle.query_str}, ) as retrieve_event: nodes = await self.aretrieve(query_bundle) retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) response = await self.asynthesize( query_bundle, nodes=nodes, ) query_event.on_end(payload={EventPayload.RESPONSE: response}) return response @property def retriever(self) -> MultiModalVectorIndexRetriever: """Get the retriever object.""" return self._retriever