import logging
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union, cast

from llama_index.agent.openai.utils import resolve_tool_choice
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import OpenAIToolCall, to_openai_tool
from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram
from llama_index.program.utils import create_list_model
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.types import Model

_logger = logging.getLogger(__name__)


def _default_tool_choice(
    output_cls: Type[Model], allow_multiple: bool = False
) -> Union[str, Dict[str, Any]]:
    """Default OpenAI tool to choose."""
    if allow_multiple:
        return "auto"
    else:
        schema = output_cls.schema()
        return resolve_tool_choice(schema["title"])

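# A hedged sketch of what this resolves to (`Album` is a hypothetical pydantic
# model; the dict shape is assumed from `resolve_tool_choice`, which maps a
# tool name to an OpenAI tool_choice payload):
#
#     _default_tool_choice(Album, allow_multiple=True)  # -> "auto"
#     _default_tool_choice(Album)
#     # -> roughly {"type": "function", "function": {"name": "Album"}},
#     #    i.e. the model is forced to call the "Album" tool
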

def _get_json_str(raw_str: str, start_idx: int) -> Tuple[Optional[str], int]:
    """Extract JSON str from raw string and start index."""
    raw_str = raw_str[start_idx:]
    stack_count = 0
    for i, c in enumerate(raw_str):
        if c == "{":
            stack_count += 1
        if c == "}":
            stack_count -= 1
            if stack_count == 0:
                return raw_str[: i + 1], i + 2 + start_idx

    return None, start_idx

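# A minimal worked example of the brace matching above (hypothetical input,
# not from the library):
#
#     raw = '[{"name": "a"}, {"name": "b"}]'
#     _get_json_str(raw, 1)   # -> ('{"name": "a"}', 15)
#     _get_json_str(raw, 15)  # -> (' {"name": "b"}', 30); the leading space is
#                             #    harmless to JSON parsing
#
# `stack_count` goes +1 on "{" and -1 on "}", so the slice ends exactly at the
# brace that closes the first top-level object after `start_idx`.
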

def _parse_tool_calls(
    tool_calls: List[OpenAIToolCall],
    output_cls: Type[Model],
    allow_multiple: bool = False,
    verbose: bool = False,
) -> Union[Model, List[Model]]:
    outputs = []
    for tool_call in tool_calls:
        function_call = tool_call.function
        # validations to get past mypy
        assert function_call is not None
        assert function_call.name is not None
        assert function_call.arguments is not None
        if verbose:
            name = function_call.name
            arguments_str = function_call.arguments
            print(f"Function call: {name} with args: {arguments_str}")

        if isinstance(function_call.arguments, dict):
            output = output_cls.parse_obj(function_call.arguments)
        else:
            output = output_cls.parse_raw(function_call.arguments)

        outputs.append(output)

    if allow_multiple:
        return outputs
    else:
        if len(outputs) > 1:
            _logger.warning(
                "Multiple outputs found, returning first one. "
                "If you want to return all outputs, set allow_multiple=True."
            )

        return outputs[0]

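# Illustration only (hypothetical field values, not from the library): a tool
# call's `function.arguments` is usually a JSON string matching the output_cls
# schema, e.g. '{"name": "OK Computer", "artist": "Radiohead"}', which
# `parse_raw` turns into an output_cls instance; if the arguments arrive as a
# dict, `parse_obj` is used instead.
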

class OpenAIPydanticProgram(BaseLLMFunctionProgram[LLM]):
    """
    An OpenAI-based function that returns a pydantic model.

    Note: this interface is not yet stable.
    """
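
    # Typical construction and use (a hedged sketch; `Album` and the prompt
    # string are hypothetical, not part of this module):
    #
    #     program = OpenAIPydanticProgram.from_defaults(
    #         output_cls=Album,
    #         prompt_template_str="Generate an album inspired by {topic}.",
    #         verbose=True,
    #     )
    #     album = program(topic="the sea")  # prompt kwargs fill the template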

    def __init__(
        self,
        output_cls: Type[Model],
        llm: LLM,
        prompt: BasePromptTemplate,
        tool_choice: Union[str, Dict[str, Any]],
        allow_multiple: bool = False,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        self._output_cls = output_cls
        self._llm = llm
        self._prompt = prompt
        self._verbose = verbose
        self._allow_multiple = allow_multiple
        self._tool_choice = tool_choice

    @classmethod
    def from_defaults(
        cls,
        output_cls: Type[Model],
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        allow_multiple: bool = False,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> "OpenAIPydanticProgram":
        llm = llm or OpenAI(model="gpt-3.5-turbo-0613")

        if not isinstance(llm, OpenAI):
            raise ValueError(
                "OpenAIPydanticProgram only supports OpenAI LLMs. " f"Got: {type(llm)}"
            )

        if not llm.metadata.is_function_calling_model:
            raise ValueError(
                f"Model name {llm.metadata.model_name} does not support "
                "function calling API."
            )

        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError(
                "Must provide either prompt or prompt_template_str, not both."
            )
        if prompt_template_str is not None:
            prompt = PromptTemplate(prompt_template_str)

        tool_choice = tool_choice or _default_tool_choice(output_cls, allow_multiple)

        return cls(
            output_cls=output_cls,
            llm=llm,
            prompt=cast(PromptTemplate, prompt),
            tool_choice=tool_choice,
            allow_multiple=allow_multiple,
            verbose=verbose,
        )

    @property
    def output_cls(self) -> Type[Model]:
        return self._output_cls

    @property
    def prompt(self) -> BasePromptTemplate:
        return self._prompt

    @prompt.setter
    def prompt(self, prompt: BasePromptTemplate) -> None:
        self._prompt = prompt

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Union[Model, List[Model]]:
        llm_kwargs = llm_kwargs or {}
        description = self._description_eval(**kwargs)

        openai_fn_spec = to_openai_tool(self._output_cls, description=description)

        messages = self._prompt.format_messages(llm=self._llm, **kwargs)

        chat_response = self._llm.chat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=self._tool_choice,
            **llm_kwargs,
        )
        message = chat_response.message
        if "tool_calls" not in message.additional_kwargs:
            raise ValueError(
                "Expected tool_calls in ai_message.additional_kwargs, "
                "but none found."
            )

        tool_calls = message.additional_kwargs["tool_calls"]
        return _parse_tool_calls(
            tool_calls,
            output_cls=self.output_cls,
            allow_multiple=self._allow_multiple,
            verbose=self._verbose,
        )

    async def acall(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Union[Model, List[Model]]:
        llm_kwargs = llm_kwargs or {}
        description = self._description_eval(**kwargs)

        openai_fn_spec = to_openai_tool(self._output_cls, description=description)

        messages = self._prompt.format_messages(llm=self._llm, **kwargs)

        chat_response = await self._llm.achat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=self._tool_choice,
            **llm_kwargs,
        )
        message = chat_response.message
        if "tool_calls" not in message.additional_kwargs:
            raise ValueError(
                "Expected tool_calls in ai_message.additional_kwargs, "
                "but none found."
            )

        tool_calls = message.additional_kwargs["tool_calls"]
        return _parse_tool_calls(
            tool_calls,
            output_cls=self.output_cls,
            allow_multiple=self._allow_multiple,
            verbose=self._verbose,
        )

    def stream_list(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Generator[Model, None, None]:
        """Streams a list of objects."""
        llm_kwargs = llm_kwargs or {}
        messages = self._prompt.format_messages(llm=self._llm, **kwargs)

        description = self._description_eval(**kwargs)

        list_output_cls = create_list_model(self._output_cls)
        openai_fn_spec = to_openai_tool(list_output_cls, description=description)

        chat_response_gen = self._llm.stream_chat(
            messages=messages,
            tools=[openai_fn_spec],
            tool_choice=_default_tool_choice(list_output_cls),
            **llm_kwargs,
        )
        # extract function call arguments
        # obj_start_idx finds start position (before a new "{" in JSON)
        obj_start_idx: int = -1  # NOTE: uninitialized
        for stream_resp in chat_response_gen:
            kwargs = stream_resp.message.additional_kwargs
            tool_calls = kwargs["tool_calls"]
            if len(tool_calls) == 0:
                continue

            # NOTE: right now assume only one tool call
            # TODO: handle parallel tool calls in streaming setting
            fn_args = kwargs["tool_calls"][0].function.arguments

            # this is inspired by `get_object` from `MultiTaskBase` in
            # the openai_function_call repo

            if fn_args.find("[") != -1:
                if obj_start_idx == -1:
                    obj_start_idx = fn_args.find("[") + 1
            else:
                # keep going until we find the start position
                continue

            new_obj_json_str, obj_start_idx = _get_json_str(fn_args, obj_start_idx)
            if new_obj_json_str is not None:
                obj_json_str = new_obj_json_str
                obj = self._output_cls.parse_raw(obj_json_str)
                if self._verbose:
                    print(f"Extracted object: {obj.json()}")
                yield obj
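
    # A minimal streaming sketch (hypothetical `Song` model and prompt, not
    # part of this module):
    #
    #     program = OpenAIPydanticProgram.from_defaults(
    #         output_cls=Song,
    #         prompt_template_str="List five songs about {topic}.",
    #     )
    #     for song in program.stream_list(topic="rain"):
    #         print(song)  # each Song is yielded as soon as its JSON object closes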

    def _description_eval(self, **kwargs: Any) -> Optional[str]:
        description = kwargs.get("description", None)

        ## __doc__ checks whether a docstring is provided on the pydantic model
        if not (self._output_cls.__doc__ or description):
            raise ValueError(
                "Must provide a description for your pydantic model. Either add "
                "a docstring to the model or pass `description=<your_description>` "
                "to the method. This is required to convert the pydantic model "
                "to an OpenAI function."
            )

        ## If both docstring and description are provided, raise error
        if self._output_cls.__doc__ and description:
            raise ValueError(
                "Must provide either a docstring or a description, not both."
            )

        return description