{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tool and Function Calling\n", "\n", "This guide demonstrates how to use SGLang’s [Funcion calling](https://platform.openai.com/docs/guides/function-calling) functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## OpenAI Compatible API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Launching the Server" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from openai import OpenAI\n", "import json\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", "from sglang.test.test_utils import is_in_ci\n", "\n", "if is_in_ci():\n", " from patch import launch_server_cmd\n", "else:\n", " from sglang.utils import launch_server_cmd\n", "\n", "\n", "server_process, port = launch_server_cmd(\n", " \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n", ")\n", "wait_for_server(f\"http://localhost:{port}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", "\n", "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Tools for Function Call\n", "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define tools\n", "tools = [\n", " {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": \"get_current_weather\",\n", " \"description\": \"Get the current weather in a given location\",\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"city\": {\n", " \"type\": \"string\",\n", " \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n", " },\n", " \"state\": {\n", " \"type\": \"string\",\n", " \"description\": \"the two-letter abbreviation for the state that the city is\"\n", " \" in, e.g. 'CA' which would mean 'California'\",\n", " },\n", " \"unit\": {\n", " \"type\": \"string\",\n", " \"description\": \"The unit to fetch the temperature in\",\n", " \"enum\": [\"celsius\", \"fahrenheit\"],\n", " },\n", " },\n", " \"required\": [\"city\", \"state\", \"unit\"],\n", " },\n", " },\n", " }\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Messages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_messages():\n", " return [\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.\",\n", " }\n", " ]\n", "\n", "\n", "messages = get_messages()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize the Client" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize OpenAI-like client\n", "client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n", "model_name = client.models.list().data[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Non-Streaming Request" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Non-streaming mode test\n", "response_non_stream = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", " temperature=0.1,\n", " top_p=0.95,\n", " max_tokens=1024,\n", " stream=False, # Non-streaming\n", " tools=tools,\n", ")\n", "print_highlight(\"Non-stream response:\")\n", "print(response_non_stream)\n", "print_highlight(\"==== content ====\")\n", "print(response_non_stream.choices[0].message.content)\n", "print_highlight(\"==== tool_calls ====\")\n", "print(response_non_stream.choices[0].message.tool_calls)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Handle Tools\n", "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n", "arguments_non_stream = (\n", " response_non_stream.choices[0].message.tool_calls[0].function.arguments\n", ")\n", "\n", "print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n", "print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Streaming Request" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Streaming mode test\n", "print_highlight(\"Streaming response:\")\n", "response_stream = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", " temperature=0.1,\n", " top_p=0.95,\n", " max_tokens=1024,\n", " stream=True, # Enable streaming\n", " tools=tools,\n", ")\n", "\n", "texts = \"\"\n", "tool_calls = []\n", "name = \"\"\n", "arguments = \"\"\n", "for chunk in response_stream:\n", " if chunk.choices[0].delta.content:\n", " texts += chunk.choices[0].delta.content\n", " if chunk.choices[0].delta.tool_calls:\n", " tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n", "print_highlight(\"==== Text ====\")\n", "print(texts)\n", "\n", "print_highlight(\"==== Tool Call ====\")\n", "for tool_call in tool_calls:\n", " print(tool_call)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Handle Tools\n", "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Parse and combine function call arguments\n", "arguments = []\n", "for tool_call in tool_calls:\n", " if tool_call.function.name:\n", " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", "\n", " if tool_call.function.arguments:\n", " arguments.append(tool_call.function.arguments)\n", "\n", "# Combine all fragments into a single JSON string\n", "full_arguments = \"\".join(arguments)\n", "print_highlight(f\"streamed function call arguments: {full_arguments}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a Tool Function" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This is a demonstration, define real function according to your usage.\n", "def get_current_weather(city: str, state: str, unit: \"str\"):\n", " return (\n", " f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n", " \"partly cloudly, with highs in the 90's.\"\n", " )\n", "\n", "\n", "available_tools = {\"get_current_weather\": get_current_weather}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Execute the Tool" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "call_data = json.loads(full_arguments)\n", "\n", "messages.append(\n", " {\n", " \"role\": \"user\",\n", " \"content\": \"\",\n", " \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n", " }\n", ")\n", "\n", "# Call the corresponding tool function\n", "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n", "tool_to_call = available_tools[tool_name]\n", "result = tool_to_call(**call_data)\n", "print_highlight(f\"Function call result: {result}\")\n", "messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n", "\n", "print_highlight(f\"Updated message history: {messages}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Send Results Back to Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "final_response = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", " temperature=0.1,\n", " top_p=0.95,\n", " stream=False,\n", " tools=tools,\n", ")\n", "print_highlight(\"Non-stream response:\")\n", "print(final_response)\n", "\n", "print_highlight(\"==== Text ====\")\n", "print(final_response.choices[0].message.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Native API and SGLang Runtime (SRT)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "import requests\n", "\n", "# generate an answer\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n", "\n", "messages = get_messages()\n", "\n", "input = tokenizer.apply_chat_template(\n", " messages,\n", " tokenize=False,\n", " add_generation_prompt=True,\n", " tools=tools,\n", ")\n", "\n", "gen_url = f\"http://localhost:{port}/generate\"\n", "gen_data = {\n", " \"text\": input,\n", " \"sampling_params\": {\n", " \"skip_special_tokens\": False,\n", " \"max_new_tokens\": 1024,\n", " \"temperature\": 0.1,\n", " \"top_p\": 0.95,\n", " },\n", "}\n", "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", "print_highlight(\"==== Reponse ====\")\n", "print(gen_response)\n", "\n", "# parse the response\n", "parse_url = f\"http://localhost:{port}/parse_function_call\"\n", "\n", "function_call_input = {\n", " \"text\": gen_response,\n", " \"tool_call_parser\": \"qwen25\",\n", " \"tools\": tools,\n", "}\n", "\n", "function_call_response = requests.post(parse_url, json=function_call_input)\n", "function_call_response_json = function_call_response.json()\n", "\n", "print_highlight(\"==== Text ====\")\n", "print(function_call_response_json[\"normal_text\"])\n", "print_highlight(\"==== Calls ====\")\n", "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n", "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Offline Engine API" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sglang as sgl\n", "from sglang.srt.function_call_parser import FunctionCallParser\n", "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", "llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n", "tokenizer = llm.tokenizer_manager.tokenizer\n", "input_ids = tokenizer.apply_chat_template(\n", " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", ")\n", "\n", "sampling_params = {\n", " \"max_new_tokens\": 1024,\n", " \"temperature\": 0.1,\n", " \"top_p\": 0.95,\n", " \"skip_special_tokens\": False,\n", "}\n", "\n", "# 1) Offline generation\n", "result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n", "generated_text = result[\"text\"] # Assume there is only one prompt\n", "\n", "print(\"=== Offline Engine Output Text ===\")\n", "print(generated_text)\n", "\n", "\n", "# 2) Parse using FunctionCallParser\n", "def convert_dict_to_tool(tool_dict: dict) -> Tool:\n", " function_dict = tool_dict.get(\"function\", {})\n", " return Tool(\n", " type=tool_dict.get(\"type\", \"function\"),\n", " function=Function(\n", " name=function_dict.get(\"name\"),\n", " description=function_dict.get(\"description\"),\n", " parameters=function_dict.get(\"parameters\"),\n", " ),\n", " )\n", "\n", "\n", "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n", "\n", "parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n", "normal_text, calls = parser.parse_non_stream(generated_text)\n", "\n", "print(\"=== Parsing Result ===\")\n", "print(\"Normal text portion:\", normal_text)\n", "print(\"Function call portion:\")\n", "for call in calls:\n", " # call: ToolCallItem\n", " print(f\" - tool name: {call.name}\")\n", " print(f\" parameters: {call.parameters}\")\n", "\n", "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm.shutdown()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to support a new model?\n", "1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n", "```\n", "\tTOOLS_TAG_LIST = [\n", "\t “<|plugin|>“,\n", "\t ““,\n", "\t “<|python_tag|>“,\n", "\t “[TOOL_CALLS]”\n", "\t]\n", "```\n", "2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n", "```\n", " class NewModelDetector(BaseFormatDetector):\n", "```\n", "3. Add the new detector to the MultiFormatParser class that manages all the format detectors." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }