391 lines
15 KiB
Python
391 lines
15 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional, cast
|
|
|
|
import llama_index
|
|
from llama_index.bridge.pydantic import BaseModel
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.embeddings.base import BaseEmbedding
|
|
from llama_index.indices.prompt_helper import PromptHelper
|
|
from llama_index.llm_predictor import LLMPredictor
|
|
from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata
|
|
from llama_index.llms.llm import LLM
|
|
from llama_index.llms.utils import LLMType, resolve_llm
|
|
from llama_index.logger import LlamaLogger
|
|
from llama_index.node_parser.interface import NodeParser, TextSplitter
|
|
from llama_index.node_parser.text.sentence import (
|
|
DEFAULT_CHUNK_SIZE,
|
|
SENTENCE_CHUNK_OVERLAP,
|
|
SentenceSplitter,
|
|
)
|
|
from llama_index.prompts.base import BasePromptTemplate
|
|
from llama_index.schema import TransformComponent
|
|
from llama_index.types import PydanticProgramMode
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_default_node_parser(
|
|
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
|
chunk_overlap: int = SENTENCE_CHUNK_OVERLAP,
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
) -> NodeParser:
|
|
"""Get default node parser."""
|
|
return SentenceSplitter(
|
|
chunk_size=chunk_size,
|
|
chunk_overlap=chunk_overlap,
|
|
callback_manager=callback_manager or CallbackManager(),
|
|
)
|
|
|
|
|
|
def _get_default_prompt_helper(
|
|
llm_metadata: LLMMetadata,
|
|
context_window: Optional[int] = None,
|
|
num_output: Optional[int] = None,
|
|
) -> PromptHelper:
|
|
"""Get default prompt helper."""
|
|
if context_window is not None:
|
|
llm_metadata.context_window = context_window
|
|
if num_output is not None:
|
|
llm_metadata.num_output = num_output
|
|
return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata)
|
|
|
|
|
|
class ServiceContextData(BaseModel):
|
|
llm: dict
|
|
llm_predictor: dict
|
|
prompt_helper: dict
|
|
embed_model: dict
|
|
transformations: List[dict]
|
|
|
|
|
|
@dataclass
|
|
class ServiceContext:
|
|
"""Service Context container.
|
|
|
|
The service context container is a utility container for LlamaIndex
|
|
index and query classes. It contains the following:
|
|
- llm_predictor: BaseLLMPredictor
|
|
- prompt_helper: PromptHelper
|
|
- embed_model: BaseEmbedding
|
|
- node_parser: NodeParser
|
|
- llama_logger: LlamaLogger (deprecated)
|
|
- callback_manager: CallbackManager
|
|
|
|
"""
|
|
|
|
llm_predictor: BaseLLMPredictor
|
|
prompt_helper: PromptHelper
|
|
embed_model: BaseEmbedding
|
|
transformations: List[TransformComponent]
|
|
llama_logger: LlamaLogger
|
|
callback_manager: CallbackManager
|
|
|
|
@classmethod
|
|
def from_defaults(
|
|
cls,
|
|
llm_predictor: Optional[BaseLLMPredictor] = None,
|
|
llm: Optional[LLMType] = "default",
|
|
prompt_helper: Optional[PromptHelper] = None,
|
|
embed_model: Optional[Any] = "default",
|
|
node_parser: Optional[NodeParser] = None,
|
|
text_splitter: Optional[TextSplitter] = None,
|
|
transformations: Optional[List[TransformComponent]] = None,
|
|
llama_logger: Optional[LlamaLogger] = None,
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
system_prompt: Optional[str] = None,
|
|
query_wrapper_prompt: Optional[BasePromptTemplate] = None,
|
|
# pydantic program mode (used if output_cls is specified)
|
|
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
|
|
# node parser kwargs
|
|
chunk_size: Optional[int] = None,
|
|
chunk_overlap: Optional[int] = None,
|
|
# prompt helper kwargs
|
|
context_window: Optional[int] = None,
|
|
num_output: Optional[int] = None,
|
|
# deprecated kwargs
|
|
chunk_size_limit: Optional[int] = None,
|
|
) -> "ServiceContext":
|
|
"""Create a ServiceContext from defaults.
|
|
If an argument is specified, then use the argument value provided for that
|
|
parameter. If an argument is not specified, then use the default value.
|
|
|
|
You can change the base defaults by setting llama_index.global_service_context
|
|
to a ServiceContext object with your desired settings.
|
|
|
|
Args:
|
|
llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor
|
|
prompt_helper (Optional[PromptHelper]): PromptHelper
|
|
embed_model (Optional[BaseEmbedding]): BaseEmbedding
|
|
or "local" (use local model)
|
|
node_parser (Optional[NodeParser]): NodeParser
|
|
llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated)
|
|
chunk_size (Optional[int]): chunk_size
|
|
callback_manager (Optional[CallbackManager]): CallbackManager
|
|
system_prompt (Optional[str]): System-wide prompt to be prepended
|
|
to all input prompts, used to guide system "decision making"
|
|
query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap
|
|
passed-in input queries.
|
|
|
|
Deprecated Args:
|
|
chunk_size_limit (Optional[int]): renamed to chunk_size
|
|
|
|
"""
|
|
from llama_index.embeddings.utils import EmbedType, resolve_embed_model
|
|
|
|
embed_model = cast(EmbedType, embed_model)
|
|
|
|
if chunk_size_limit is not None and chunk_size is None:
|
|
logger.warning(
|
|
"chunk_size_limit is deprecated, please specify chunk_size instead"
|
|
)
|
|
chunk_size = chunk_size_limit
|
|
|
|
if llama_index.global_service_context is not None:
|
|
return cls.from_service_context(
|
|
llama_index.global_service_context,
|
|
llm=llm,
|
|
llm_predictor=llm_predictor,
|
|
prompt_helper=prompt_helper,
|
|
embed_model=embed_model,
|
|
node_parser=node_parser,
|
|
text_splitter=text_splitter,
|
|
llama_logger=llama_logger,
|
|
callback_manager=callback_manager,
|
|
context_window=context_window,
|
|
chunk_size=chunk_size,
|
|
chunk_size_limit=chunk_size_limit,
|
|
chunk_overlap=chunk_overlap,
|
|
num_output=num_output,
|
|
system_prompt=system_prompt,
|
|
query_wrapper_prompt=query_wrapper_prompt,
|
|
transformations=transformations,
|
|
)
|
|
|
|
callback_manager = callback_manager or CallbackManager([])
|
|
if llm != "default":
|
|
if llm_predictor is not None:
|
|
raise ValueError("Cannot specify both llm and llm_predictor")
|
|
llm = resolve_llm(llm)
|
|
llm.system_prompt = llm.system_prompt or system_prompt
|
|
llm.query_wrapper_prompt = llm.query_wrapper_prompt or query_wrapper_prompt
|
|
llm.pydantic_program_mode = (
|
|
llm.pydantic_program_mode or pydantic_program_mode
|
|
)
|
|
|
|
if llm_predictor is not None:
|
|
print("LLMPredictor is deprecated, please use LLM instead.")
|
|
llm_predictor = llm_predictor or LLMPredictor(
|
|
llm=llm, pydantic_program_mode=pydantic_program_mode
|
|
)
|
|
if isinstance(llm_predictor, LLMPredictor):
|
|
llm_predictor.llm.callback_manager = callback_manager
|
|
if system_prompt:
|
|
llm_predictor.system_prompt = system_prompt
|
|
if query_wrapper_prompt:
|
|
llm_predictor.query_wrapper_prompt = query_wrapper_prompt
|
|
|
|
# NOTE: the embed_model isn't used in all indices
|
|
# NOTE: embed model should be a transformation, but the way the service
|
|
# context works, we can't put in there yet.
|
|
embed_model = resolve_embed_model(embed_model)
|
|
embed_model.callback_manager = callback_manager
|
|
|
|
prompt_helper = prompt_helper or _get_default_prompt_helper(
|
|
llm_metadata=llm_predictor.metadata,
|
|
context_window=context_window,
|
|
num_output=num_output,
|
|
)
|
|
|
|
if text_splitter is not None and node_parser is not None:
|
|
raise ValueError("Cannot specify both text_splitter and node_parser")
|
|
|
|
node_parser = (
|
|
text_splitter # text splitter extends node parser
|
|
or node_parser
|
|
or _get_default_node_parser(
|
|
chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
|
|
chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
|
|
callback_manager=callback_manager,
|
|
)
|
|
)
|
|
|
|
transformations = transformations or [node_parser]
|
|
|
|
llama_logger = llama_logger or LlamaLogger()
|
|
|
|
return cls(
|
|
llm_predictor=llm_predictor,
|
|
embed_model=embed_model,
|
|
prompt_helper=prompt_helper,
|
|
transformations=transformations,
|
|
llama_logger=llama_logger, # deprecated
|
|
callback_manager=callback_manager,
|
|
)
|
|
|
|
@classmethod
|
|
def from_service_context(
|
|
cls,
|
|
service_context: "ServiceContext",
|
|
llm_predictor: Optional[BaseLLMPredictor] = None,
|
|
llm: Optional[LLMType] = "default",
|
|
prompt_helper: Optional[PromptHelper] = None,
|
|
embed_model: Optional[Any] = "default",
|
|
node_parser: Optional[NodeParser] = None,
|
|
text_splitter: Optional[TextSplitter] = None,
|
|
transformations: Optional[List[TransformComponent]] = None,
|
|
llama_logger: Optional[LlamaLogger] = None,
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
system_prompt: Optional[str] = None,
|
|
query_wrapper_prompt: Optional[BasePromptTemplate] = None,
|
|
# node parser kwargs
|
|
chunk_size: Optional[int] = None,
|
|
chunk_overlap: Optional[int] = None,
|
|
# prompt helper kwargs
|
|
context_window: Optional[int] = None,
|
|
num_output: Optional[int] = None,
|
|
# deprecated kwargs
|
|
chunk_size_limit: Optional[int] = None,
|
|
) -> "ServiceContext":
|
|
"""Instantiate a new service context using a previous as the defaults."""
|
|
from llama_index.embeddings.utils import EmbedType, resolve_embed_model
|
|
|
|
embed_model = cast(EmbedType, embed_model)
|
|
|
|
if chunk_size_limit is not None and chunk_size is None:
|
|
logger.warning(
|
|
"chunk_size_limit is deprecated, please specify chunk_size",
|
|
DeprecationWarning,
|
|
)
|
|
chunk_size = chunk_size_limit
|
|
|
|
callback_manager = callback_manager or service_context.callback_manager
|
|
if llm != "default":
|
|
if llm_predictor is not None:
|
|
raise ValueError("Cannot specify both llm and llm_predictor")
|
|
llm = resolve_llm(llm)
|
|
llm_predictor = LLMPredictor(llm=llm)
|
|
|
|
llm_predictor = llm_predictor or service_context.llm_predictor
|
|
if isinstance(llm_predictor, LLMPredictor):
|
|
llm_predictor.llm.callback_manager = callback_manager
|
|
if system_prompt:
|
|
llm_predictor.system_prompt = system_prompt
|
|
if query_wrapper_prompt:
|
|
llm_predictor.query_wrapper_prompt = query_wrapper_prompt
|
|
|
|
# NOTE: the embed_model isn't used in all indices
|
|
# default to using the embed model passed from the service context
|
|
if embed_model == "default":
|
|
embed_model = service_context.embed_model
|
|
embed_model = resolve_embed_model(embed_model)
|
|
embed_model.callback_manager = callback_manager
|
|
|
|
prompt_helper = prompt_helper or service_context.prompt_helper
|
|
if context_window is not None or num_output is not None:
|
|
prompt_helper = _get_default_prompt_helper(
|
|
llm_metadata=llm_predictor.metadata,
|
|
context_window=context_window,
|
|
num_output=num_output,
|
|
)
|
|
|
|
transformations = transformations or []
|
|
node_parser_found = False
|
|
for transform in service_context.transformations:
|
|
if isinstance(transform, NodeParser):
|
|
node_parser_found = True
|
|
node_parser = transform
|
|
break
|
|
|
|
if text_splitter is not None and node_parser is not None:
|
|
raise ValueError("Cannot specify both text_splitter and node_parser")
|
|
|
|
if not node_parser_found:
|
|
node_parser = (
|
|
text_splitter # text splitter extends node parser
|
|
or node_parser
|
|
or _get_default_node_parser(
|
|
chunk_size=chunk_size or DEFAULT_CHUNK_SIZE,
|
|
chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP,
|
|
callback_manager=callback_manager,
|
|
)
|
|
)
|
|
|
|
transformations = transformations or service_context.transformations
|
|
|
|
llama_logger = llama_logger or service_context.llama_logger
|
|
|
|
return cls(
|
|
llm_predictor=llm_predictor,
|
|
embed_model=embed_model,
|
|
prompt_helper=prompt_helper,
|
|
transformations=transformations,
|
|
llama_logger=llama_logger, # deprecated
|
|
callback_manager=callback_manager,
|
|
)
|
|
|
|
@property
|
|
def llm(self) -> LLM:
|
|
return self.llm_predictor.llm
|
|
|
|
@property
|
|
def node_parser(self) -> NodeParser:
|
|
"""Get the node parser."""
|
|
for transform in self.transformations:
|
|
if isinstance(transform, NodeParser):
|
|
return transform
|
|
raise ValueError("No node parser found.")
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert service context to dict."""
|
|
llm_dict = self.llm_predictor.llm.to_dict()
|
|
llm_predictor_dict = self.llm_predictor.to_dict()
|
|
|
|
embed_model_dict = self.embed_model.to_dict()
|
|
|
|
prompt_helper_dict = self.prompt_helper.to_dict()
|
|
|
|
tranform_list_dict = [x.to_dict() for x in self.transformations]
|
|
|
|
return ServiceContextData(
|
|
llm=llm_dict,
|
|
llm_predictor=llm_predictor_dict,
|
|
prompt_helper=prompt_helper_dict,
|
|
embed_model=embed_model_dict,
|
|
transformations=tranform_list_dict,
|
|
).dict()
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: dict) -> "ServiceContext":
|
|
from llama_index.embeddings.loading import load_embed_model
|
|
from llama_index.extractors.loading import load_extractor
|
|
from llama_index.llm_predictor.loading import load_predictor
|
|
from llama_index.node_parser.loading import load_parser
|
|
|
|
service_context_data = ServiceContextData.parse_obj(data)
|
|
|
|
llm_predictor = load_predictor(service_context_data.llm_predictor)
|
|
|
|
embed_model = load_embed_model(service_context_data.embed_model)
|
|
|
|
prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper)
|
|
|
|
transformations: List[TransformComponent] = []
|
|
for transform in service_context_data.transformations:
|
|
try:
|
|
transformations.append(load_parser(transform))
|
|
except ValueError:
|
|
transformations.append(load_extractor(transform))
|
|
|
|
return cls.from_defaults(
|
|
llm_predictor=llm_predictor,
|
|
prompt_helper=prompt_helper,
|
|
embed_model=embed_model,
|
|
transformations=transformations,
|
|
)
|
|
|
|
|
|
def set_global_service_context(service_context: Optional[ServiceContext]) -> None:
|
|
"""Helper function to set the global service context."""
|
|
llama_index.global_service_context = service_context
|