"""Utils for jupyter notebook.""" import os from io import BytesIO from typing import Any, Dict, List, Tuple import matplotlib.pyplot as plt import requests from IPython.display import Markdown, display from PIL import Image from llama_index.core.response.schema import Response from llama_index.img_utils import b64_2_img from llama_index.schema import ImageNode, MetadataMode, NodeWithScore from llama_index.utils import truncate_text DEFAULT_THUMBNAIL_SIZE = (512, 512) DEFAULT_IMAGE_MATRIX = (3, 3) DEFAULT_SHOW_TOP_K = 3 def display_image(img_str: str, size: Tuple[int, int] = DEFAULT_THUMBNAIL_SIZE) -> None: """Display base64 encoded image str as image for jupyter notebook.""" img = b64_2_img(img_str) img.thumbnail(size) display(img) def display_image_uris( image_paths: List[str], image_matrix: Tuple[int, int] = DEFAULT_IMAGE_MATRIX, top_k: int = DEFAULT_SHOW_TOP_K, ) -> None: """Display base64 encoded image str as image for jupyter notebook.""" images_shown = 0 plt.figure(figsize=(16, 9)) for img_path in image_paths[:top_k]: if os.path.isfile(img_path): image = Image.open(img_path) plt.subplot(image_matrix[0], image_matrix[1], images_shown + 1) plt.imshow(image) plt.xticks([]) plt.yticks([]) images_shown += 1 if images_shown >= image_matrix[0] * image_matrix[1]: break def display_source_node( source_node: NodeWithScore, source_length: int = 100, show_source_metadata: bool = False, metadata_mode: MetadataMode = MetadataMode.NONE, ) -> None: """Display source node for jupyter notebook.""" source_text_fmt = truncate_text( source_node.node.get_content(metadata_mode=metadata_mode).strip(), source_length ) text_md = ( f"**Node ID:** {source_node.node.node_id}
" f"**Similarity:** {source_node.score}
" f"**Text:** {source_text_fmt}
" ) if show_source_metadata: text_md += f"**Metadata:** {source_node.node.metadata}
" if isinstance(source_node.node, ImageNode): text_md += "**Image:**" display(Markdown(text_md)) if isinstance(source_node.node, ImageNode) and source_node.node.image is not None: display_image(source_node.node.image) def display_metadata(metadata: Dict[str, Any]) -> None: """Display metadata for jupyter notebook.""" display(metadata) def display_response( response: Response, source_length: int = 100, show_source: bool = False, show_metadata: bool = False, show_source_metadata: bool = False, ) -> None: """Display response for jupyter notebook.""" if response.response is None: response_text = "None" else: response_text = response.response.strip() display(Markdown(f"**`Final Response:`** {response_text}")) if show_source: for ind, source_node in enumerate(response.source_nodes): display(Markdown("---")) display( Markdown(f"**`Source Node {ind + 1}/{len(response.source_nodes)}`**") ) display_source_node( source_node, source_length=source_length, show_source_metadata=show_source_metadata, ) if show_metadata: if response.metadata is not None: display_metadata(response.metadata) def display_query_and_multimodal_response( query_str: str, response: Response, plot_height: int = 2, plot_width: int = 5 ) -> None: """For displaying a query and its multi-modal response.""" if response.metadata: image_nodes = response.metadata["image_nodes"] or [] else: image_nodes = [] num_subplots = len(image_nodes) f, axarr = plt.subplots(1, num_subplots) f.set_figheight(plot_height) f.set_figwidth(plot_width) ix = 0 for ix, scored_img_node in enumerate(image_nodes): img_node = scored_img_node.node image = None if img_node.image_url: img_response = requests.get(img_node.image_url) image = Image.open(BytesIO(img_response.content)) elif img_node.image_path: image = Image.open(img_node.image_path).convert("RGB") else: raise ValueError( "A retrieved image must have image_path or image_url specified." ) if num_subplots > 1: axarr[ix].imshow(image) axarr[ix].set_title(f"Retrieved Position: {ix}", pad=10, fontsize=9) else: axarr.imshow(image) axarr.set_title(f"Retrieved Position: {ix}", pad=10, fontsize=9) f.tight_layout() print(f"Query: {query_str}\n=======") print(f"Retrieved Images:\n") plt.show() print("=======") print(f"Response: {response.response}\n=======\n")