import logging
from dataclasses import dataclass
from typing import Any, List, Optional, cast

import llama_index
from llama_index.bridge.pydantic import BaseModel
from llama_index.callbacks.base import CallbackManager
from llama_index.core.embeddings.base import BaseEmbedding
from llama_index.indices.prompt_helper import PromptHelper
from llama_index.llm_predictor import LLMPredictor
from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata
from llama_index.llms.llm import LLM
from llama_index.llms.utils import LLMType, resolve_llm
from llama_index.logger import LlamaLogger
from llama_index.node_parser.interface import NodeParser, TextSplitter
from llama_index.node_parser.text.sentence import (
    DEFAULT_CHUNK_SIZE,
    SENTENCE_CHUNK_OVERLAP,
    SentenceSplitter,
)
from llama_index.prompts.base import BasePromptTemplate
from llama_index.schema import TransformComponent
from llama_index.types import PydanticProgramMode

logger = logging.getLogger(__name__)


def _get_default_node_parser(
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = SENTENCE_CHUNK_OVERLAP,
    callback_manager: Optional[CallbackManager] = None,
) -> NodeParser:
    """Get default node parser."""
    return SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        callback_manager=callback_manager or CallbackManager(),
    )


def _get_default_prompt_helper(
    llm_metadata: LLMMetadata,
    context_window: Optional[int] = None,
    num_output: Optional[int] = None,
) -> PromptHelper:
    """Get default prompt helper."""
    if context_window is not None:
        llm_metadata.context_window = context_window
    if num_output is not None:
        llm_metadata.num_output = num_output
    return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata)


class ServiceContextData(BaseModel):
    llm: dict
    llm_predictor: dict
    prompt_helper: dict
    embed_model: dict
    transformations: List[dict]
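
# NOTE: ServiceContextData (above) is the serialization schema used by
# ServiceContext.to_dict() / ServiceContext.from_dict() below; each field
# holds the to_dict() payload of the corresponding component.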

@dataclass
class ServiceContext:
    """Service Context container.

    The service context container is a utility container for LlamaIndex
    index and query classes. It contains the following:

    - llm_predictor: BaseLLMPredictor
    - prompt_helper: PromptHelper
    - embed_model: BaseEmbedding
    - node_parser: NodeParser
    - llama_logger: LlamaLogger (deprecated)
    - callback_manager: CallbackManager

    """

    llm_predictor: BaseLLMPredictor
    prompt_helper: PromptHelper
    embed_model: BaseEmbedding
    transformations: List[TransformComponent]
    llama_logger: LlamaLogger
    callback_manager: CallbackManager

    @classmethod
    def from_defaults(
        cls,
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # pydantic program mode (used if output_cls is specified)
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Create a ServiceContext from defaults.

        If an argument is specified, then use the argument value provided for
        that parameter. If an argument is not specified, then use the default
        value.

        You can change the base defaults by setting
        llama_index.global_service_context to a ServiceContext object with
        your desired settings.

        Args:
            llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor
            prompt_helper (Optional[PromptHelper]): PromptHelper
            embed_model (Optional[BaseEmbedding]): BaseEmbedding
                or "local" (use local model)
            node_parser (Optional[NodeParser]): NodeParser
            llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated)
            chunk_size (Optional[int]): chunk_size
            callback_manager (Optional[CallbackManager]): CallbackManager
            system_prompt (Optional[str]): System-wide prompt to be prepended
                to all input prompts, used to guide system "decision making"
            query_wrapper_prompt (Optional[BasePromptTemplate]): A format to
                wrap passed-in input queries.

        Deprecated Args:
            chunk_size_limit (Optional[int]): renamed to chunk_size

        """
        from llama_index.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        if llama_index.global_service_context is not None:
            return cls.from_service_context(
                llama_index.global_service_context,
                llm=llm,
                llm_predictor=llm_predictor,
                prompt_helper=prompt_helper,
                embed_model=embed_model,
                node_parser=node_parser,
                text_splitter=text_splitter,
                llama_logger=llama_logger,
                callback_manager=callback_manager,
                context_window=context_window,
                chunk_size=chunk_size,
                chunk_size_limit=chunk_size_limit,
                chunk_overlap=chunk_overlap,
                num_output=num_output,
                system_prompt=system_prompt,
                query_wrapper_prompt=query_wrapper_prompt,
                transformations=transformations,
            )

        callback_manager = callback_manager or CallbackManager([])
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm.system_prompt = llm.system_prompt or system_prompt
            llm.query_wrapper_prompt = llm.query_wrapper_prompt or query_wrapper_prompt
            llm.pydantic_program_mode = (
                llm.pydantic_program_mode or pydantic_program_mode
            )

        if llm_predictor is not None:
            logger.warning("LLMPredictor is deprecated, please use LLM instead.")
        llm_predictor = llm_predictor or LLMPredictor(
            llm=llm, pydantic_program_mode=pydantic_program_mode
        )
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # NOTE: embed model should be a transformation, but the way the service
        # context works, we can't put it in there yet.
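        # resolve_embed_model maps string aliases to concrete embedding classes;
        # per the docstring above, embed_model may be a BaseEmbedding instance
        # or "local" (use a local model). See llama_index.embeddings.utils for
        # the authoritative resolution rules.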
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or _get_default_prompt_helper(
            llm_metadata=llm_predictor.metadata,
            context_window=context_window,
            num_output=num_output,
        )

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")

        node_parser = (
            text_splitter  # text splitter extends node parser
            or node_parser
            or _get_default_node_parser(
                chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                callback_manager=callback_manager,
            )
        )

        transformations = transformations or [node_parser]

        llama_logger = llama_logger or LlamaLogger()

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )

    @classmethod
    def from_service_context(
        cls,
        service_context: "ServiceContext",
        llm_predictor: Optional[BaseLLMPredictor] = None,
        llm: Optional[LLMType] = "default",
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[Any] = "default",
        node_parser: Optional[NodeParser] = None,
        text_splitter: Optional[TextSplitter] = None,
        transformations: Optional[List[TransformComponent]] = None,
        llama_logger: Optional[LlamaLogger] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # prompt helper kwargs
        context_window: Optional[int] = None,
        num_output: Optional[int] = None,
        # deprecated kwargs
        chunk_size_limit: Optional[int] = None,
    ) -> "ServiceContext":
        """Instantiate a new service context using a previous one as the defaults."""
        from llama_index.embeddings.utils import EmbedType, resolve_embed_model

        embed_model = cast(EmbedType, embed_model)

        if chunk_size_limit is not None and chunk_size is None:
            logger.warning(
                "chunk_size_limit is deprecated, please specify chunk_size instead"
            )
            chunk_size = chunk_size_limit

        callback_manager = callback_manager or service_context.callback_manager
        if llm != "default":
            if llm_predictor is not None:
                raise ValueError("Cannot specify both llm and llm_predictor")
            llm = resolve_llm(llm)
            llm_predictor = LLMPredictor(llm=llm)

        llm_predictor = llm_predictor or service_context.llm_predictor
        if isinstance(llm_predictor, LLMPredictor):
            llm_predictor.llm.callback_manager = callback_manager
            if system_prompt:
                llm_predictor.system_prompt = system_prompt
            if query_wrapper_prompt:
                llm_predictor.query_wrapper_prompt = query_wrapper_prompt

        # NOTE: the embed_model isn't used in all indices
        # default to using the embed model passed from the service context
        if embed_model == "default":
            embed_model = service_context.embed_model
        embed_model = resolve_embed_model(embed_model)
        embed_model.callback_manager = callback_manager

        prompt_helper = prompt_helper or service_context.prompt_helper
        if context_window is not None or num_output is not None:
            prompt_helper = _get_default_prompt_helper(
                llm_metadata=llm_predictor.metadata,
                context_window=context_window,
                num_output=num_output,
            )

        transformations = transformations or []
        node_parser_found = False
        for transform in service_context.transformations:
            if isinstance(transform, NodeParser):
                node_parser_found = True
                node_parser = transform
                break

        if text_splitter is not None and node_parser is not None:
            raise ValueError("Cannot specify both text_splitter and node_parser")
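        # Node parser precedence when the inherited context does not already
        # provide one: an explicit text_splitter wins (TextSplitter extends
        # NodeParser), then an explicit node_parser, then the default
        # SentenceSplitter built from chunk_size / chunk_overlap.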
        if not node_parser_found:
            node_parser = (
                text_splitter  # text splitter extends node parser
                or node_parser
                or _get_default_node_parser(
                    chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
                    chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
                    callback_manager=callback_manager,
                )
            )

        transformations = transformations or service_context.transformations

        llama_logger = llama_logger or service_context.llama_logger

        return cls(
            llm_predictor=llm_predictor,
            embed_model=embed_model,
            prompt_helper=prompt_helper,
            transformations=transformations,
            llama_logger=llama_logger,  # deprecated
            callback_manager=callback_manager,
        )

    @property
    def llm(self) -> LLM:
        return self.llm_predictor.llm

    @property
    def node_parser(self) -> NodeParser:
        """Get the node parser."""
        for transform in self.transformations:
            if isinstance(transform, NodeParser):
                return transform
        raise ValueError("No node parser found.")

    def to_dict(self) -> dict:
        """Convert service context to dict."""
        llm_dict = self.llm_predictor.llm.to_dict()
        llm_predictor_dict = self.llm_predictor.to_dict()

        embed_model_dict = self.embed_model.to_dict()

        prompt_helper_dict = self.prompt_helper.to_dict()

        transform_list_dict = [x.to_dict() for x in self.transformations]

        return ServiceContextData(
            llm=llm_dict,
            llm_predictor=llm_predictor_dict,
            prompt_helper=prompt_helper_dict,
            embed_model=embed_model_dict,
            transformations=transform_list_dict,
        ).dict()

    @classmethod
    def from_dict(cls, data: dict) -> "ServiceContext":
        from llama_index.embeddings.loading import load_embed_model
        from llama_index.extractors.loading import load_extractor
        from llama_index.llm_predictor.loading import load_predictor
        from llama_index.node_parser.loading import load_parser

        service_context_data = ServiceContextData.parse_obj(data)

        llm_predictor = load_predictor(service_context_data.llm_predictor)

        embed_model = load_embed_model(service_context_data.embed_model)

        prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper)

        transformations: List[TransformComponent] = []
        for transform in service_context_data.transformations:
            try:
                # try to load each transformation as a node parser first;
                # anything load_parser rejects is loaded as an extractor
                transformations.append(load_parser(transform))
            except ValueError:
                transformations.append(load_extractor(transform))

        return cls.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=embed_model,
            transformations=transformations,
        )


def set_global_service_context(service_context: Optional[ServiceContext]) -> None:
    """Helper function to set the global service context."""
    llama_index.global_service_context = service_context
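
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API): how a caller
# might build a ServiceContext and install it as the global default. The
# OpenAI LLM and model name below are assumptions for the example, not
# requirements of this module; any LLMType accepted by resolve_llm works.
#
#     from llama_index.llms import OpenAI
#
#     ctx = ServiceContext.from_defaults(
#         llm=OpenAI(model="gpt-3.5-turbo"),
#         chunk_size=512,
#         chunk_overlap=20,
#     )
#     set_global_service_context(ctx)
#
# Once the global context is set, subsequent ServiceContext.from_defaults()
# calls inherit ctx's settings via llama_index.global_service_context.
# ---------------------------------------------------------------------------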