{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tool and Function Calling\n",
"\n",
"This guide demonstrates how to use SGLang’s [function calling](https://platform.openai.com/docs/guides/function-calling) functionality."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAI Compatible API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launching the Server"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"import json\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
"    from patch import launch_server_cmd\n",
"else:\n",
"    from sglang.utils import launch_server_cmd\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
"    \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"  # qwen25\n",
")\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n",
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). In particular, for QwQ the reasoning parser can be enabled together with the tool call parser; see [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html) for details. An example launch command for another parser is sketched below."
]
},
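{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, to serve a Llama 3.x model instead, the matching parser from the list above would be passed (a sketch, not run in this notebook):\n",
"\n",
"```\n",
"server_process, port = launch_server_cmd(\n",
"    \"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\"\n",
")\n",
"```"
]
},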
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Tools for Function Call\n",
"Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and its parameters, specified as a JSON Schema object with typed properties."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define tools\n",
"tools = [\n",
"    {\n",
"        \"type\": \"function\",\n",
"        \"function\": {\n",
"            \"name\": \"get_current_weather\",\n",
"            \"description\": \"Get the current weather in a given location\",\n",
"            \"parameters\": {\n",
"                \"type\": \"object\",\n",
"                \"properties\": {\n",
"                    \"city\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
"                    },\n",
"                    \"state\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"The two-letter abbreviation for the state that the city is\"\n",
"                        \" in, e.g. 'CA' which would mean 'California'\",\n",
"                    },\n",
"                    \"unit\": {\n",
"                        \"type\": \"string\",\n",
"                        \"description\": \"The unit to fetch the temperature in\",\n",
"                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
"                    },\n",
"                },\n",
"                \"required\": [\"city\", \"state\", \"unit\"],\n",
"            },\n",
"        },\n",
"    }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_messages():\n",
"    return [\n",
"        {\n",
"            \"role\": \"user\",\n",
"            \"content\": \"What's the weather like in Boston today? Output your reasoning before acting, then use the tools to help you.\",\n",
"        }\n",
"    ]\n",
"\n",
"\n",
"messages = get_messages()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the Client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Non-streaming mode test\n",
"response_non_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.1,\n",
"    top_p=0.95,\n",
"    max_tokens=1024,\n",
"    stream=False,  # Non-streaming\n",
"    tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(response_non_stream)\n",
"print_highlight(\"==== content ====\")\n",
"print(response_non_stream.choices[0].message.content)\n",
"print_highlight(\"==== tool_calls ====\")\n",
"print(response_non_stream.choices[0].message.tool_calls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Handle Tools\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
"arguments_non_stream = (\n",
"    response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
")\n",
"\n",
"print_highlight(f\"Non-stream function call name: {name_non_stream}\")\n",
"print_highlight(f\"Non-stream function call arguments: {arguments_non_stream}\")"
]
},
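{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `arguments` field is a JSON string, so it can be decoded into a Python dict before the tool is invoked (a minimal sketch; the matching tool function is defined later in this guide):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode the JSON arguments string into keyword arguments for the tool\n",
"args_non_stream = json.loads(arguments_non_stream)\n",
"print_highlight(f\"Parsed arguments: {args_non_stream}\")"
]
},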
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming mode test\n",
"print_highlight(\"Streaming response:\")\n",
"response_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.1,\n",
"    top_p=0.95,\n",
"    max_tokens=1024,\n",
"    stream=True,  # Enable streaming\n",
"    tools=tools,\n",
")\n",
"\n",
"texts = \"\"\n",
"tool_calls = []\n",
"for chunk in response_stream:\n",
"    if chunk.choices[0].delta.content:\n",
"        texts += chunk.choices[0].delta.content\n",
"    if chunk.choices[0].delta.tool_calls:\n",
"        tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
"print_highlight(\"==== Text ====\")\n",
"print(texts)\n",
"\n",
"print_highlight(\"==== Tool Call ====\")\n",
"for tool_call in tool_calls:\n",
"    print(tool_call)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Handle Tools\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parse and combine function call arguments\n",
"arguments = []\n",
"for tool_call in tool_calls:\n",
"    if tool_call.function.name:\n",
"        print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
"\n",
"    if tool_call.function.arguments:\n",
"        arguments.append(tool_call.function.arguments)\n",
"\n",
"# Combine all fragments into a single JSON string\n",
"full_arguments = \"\".join(arguments)\n",
"print_highlight(f\"Streamed function call arguments: {full_arguments}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a Tool Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is a demonstration; define a real function according to your use case.\n",
"def get_current_weather(city: str, state: str, unit: str):\n",
"    return (\n",
"        f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
"        \"partly cloudy, with highs in the 90s.\"\n",
"    )\n",
"\n",
"\n",
"available_tools = {\"get_current_weather\": get_current_weather}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Execute the Tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"call_data = json.loads(full_arguments)\n",
"\n",
"messages.append(\n",
"    {\n",
"        \"role\": \"assistant\",  # the tool call is issued by the assistant, not the user\n",
"        \"content\": \"\",\n",
"        \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
"    }\n",
")\n",
"\n",
"# Call the corresponding tool function\n",
"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
"tool_to_call = available_tools[tool_name]\n",
"result = tool_to_call(**call_data)\n",
"print_highlight(f\"Function call result: {result}\")\n",
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
"\n",
"print_highlight(f\"Updated message history: {messages}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Send Results Back to Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_response = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.1,\n",
"    top_p=0.95,\n",
"    stream=False,\n",
"    tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(final_response)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print(final_response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import requests\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n",
"\n",
"messages = get_messages()\n",
"\n",
"input_text = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
"    tools=tools,\n",
")\n",
"\n",
"gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\n",
"    \"text\": input_text,\n",
"    \"sampling_params\": {\n",
"        \"skip_special_tokens\": False,\n",
"        \"max_new_tokens\": 1024,\n",
"        \"temperature\": 0.1,\n",
"        \"top_p\": 0.95,\n",
"    },\n",
"}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print_highlight(\"==== Response ====\")\n",
"print(gen_response)\n",
"\n",
"# parse the response\n",
"parse_url = f\"http://localhost:{port}/parse_function_call\"\n",
"\n",
"function_call_input = {\n",
"    \"text\": gen_response,\n",
"    \"tool_call_parser\": \"qwen25\",\n",
"    \"tools\": tools,\n",
"}\n",
"\n",
"function_call_response = requests.post(parse_url, json=function_call_input)\n",
"function_call_response_json = function_call_response.json()\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print(function_call_response_json[\"normal_text\"])\n",
"print_highlight(\"==== Calls ====\")\n",
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
"tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n",
"    messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n",
"\n",
"sampling_params = {\n",
"    \"max_new_tokens\": 1024,\n",
"    \"temperature\": 0.1,\n",
"    \"top_p\": 0.95,\n",
"    \"skip_special_tokens\": False,\n",
"}\n",
"\n",
"# 1) Offline generation\n",
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
"generated_text = result[\"text\"]  # Assume there is only one prompt\n",
"\n",
"print(\"=== Offline Engine Output Text ===\")\n",
"print(generated_text)\n",
"\n",
"\n",
"# 2) Parse using FunctionCallParser\n",
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
"    function_dict = tool_dict.get(\"function\", {})\n",
"    return Tool(\n",
"        type=tool_dict.get(\"type\", \"function\"),\n",
"        function=Function(\n",
"            name=function_dict.get(\"name\"),\n",
"            description=function_dict.get(\"description\"),\n",
"            parameters=function_dict.get(\"parameters\"),\n",
"        ),\n",
"    )\n",
"\n",
"\n",
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
"\n",
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n",
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
"\n",
"print(\"=== Parsing Result ===\")\n",
"print(\"Normal text portion:\", normal_text)\n",
"print(\"Function call portion:\")\n",
"for call in calls:\n",
"    # call: ToolCallItem\n",
"    print(f\"  - tool name: {call.name}\")\n",
"    print(f\"    parameters: {call.parameters}\")\n",
"\n",
"# 3) If needed, apply additional logic to the parsed calls, such as automatically invoking the corresponding function to obtain its return value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to support a new model?\n",
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
"```\n",
"\tTOOLS_TAG_LIST = [\n",
"\t    \"<|plugin|>\",\n",
"\t    \"<function=\",\n",
"\t    \"<tool_call>\",\n",
"\t    \"<|python_tag|>\",\n",
"\t    \"[TOOL_CALLS]\"\n",
"\t]\n",
"```\n",
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
"```\n",
"    class NewModelDetector(BaseFormatDetector):\n",
"```\n",
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors. A hypothetical skeleton is sketched in the cell below."
]
}
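,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of step 2, the skeleton below shows the general shape of a new detector. The tag string and the comments are hypothetical placeholders, not the real BaseFormatDetector interface; consult sglang/srt/function_call_parser.py for the actual methods to override:\n",
"\n",
"```\n",
"class NewModelDetector(BaseFormatDetector):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Hypothetical: the tag that opens a tool call in this model's output,\n",
"        # matching the entry added to TOOLS_TAG_LIST in step 1.\n",
"        self.tool_call_start_tag = \"<new_tool_call>\"\n",
"\n",
"    # Override the base class's parsing hooks here to extract the tool name\n",
"    # and JSON arguments from text framed by the tag above, then register\n",
"    # the detector with MultiFormatParser (step 3).\n",
"```"
]
}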
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}