{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Reasoning Parser\n", "\n", "SGLang can separate reasoning content from \"normal\" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).\n", "\n", "## Supported Models & Parsers\n", "\n", "| Model | Reasoning tags | Parser | Notes |\n", "|---------|-----------------------------|------------------|-------|\n", "| [DeepSeek‑R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `<think>` … `</think>` | `deepseek-r1` | Supports all variants (R1, R1-0528, R1-Distill) |\n", "| [DeepSeek‑V3.1](https://huggingface.co/deepseek-ai/DeepSeek-V3.1) | `<think>` … `</think>` | `deepseek-v3` | Supports `thinking` parameter |\n", "| [Standard Qwen3 models](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `<think>` … `</think>` | `qwen3` | Supports `enable_thinking` parameter |\n", "| [Qwen3-Thinking models](https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507) | `<think>` … `</think>` | `qwen3` or `qwen3-thinking` | Always generates thinking content |\n", "| [Kimi models](https://huggingface.co/moonshotai/models) | `◁think▷` … `◁/think▷` | `kimi` | Uses special thinking delimiters |\n", "| [GPT OSS](https://huggingface.co/openai/gpt-oss-120b) | `<\\|channel\\|>analysis<\\|message\\|>` … `<\\|end\\|>` | `gpt-oss` | N/A |\n", "\n", "### Model-Specific Behaviors\n", "\n", "**DeepSeek-R1 Family:**\n", "- DeepSeek-R1: No `<think>` start tag; jumps directly to thinking content\n", "- DeepSeek-R1-0528: Generates both the `<think>` start tag and the `</think>` end tag\n", "- Both are handled by the same `deepseek-r1` parser\n", "\n", "**DeepSeek-V3 Family:**\n", "- DeepSeek-V3.1: Hybrid model supporting both thinking and non-thinking modes; use the `deepseek-v3` parser and the `thinking` parameter (NOTE: not `enable_thinking`)\n", "\n", "**Qwen3 Family:**\n", "- Standard Qwen3 (e.g., Qwen3-2507): Use the `qwen3` parser; supports `enable_thinking` in chat templates\n", "- Qwen3-Thinking (e.g., 
Qwen3-235B-A22B-Thinking-2507): Use the `qwen3` or `qwen3-thinking` parser; always thinks\n", "\n", "**Kimi:**\n", "- Kimi: Uses special `◁think▷` and `◁/think▷` tags\n", "\n", "**GPT OSS:**\n", "- GPT OSS: Uses special `<|channel|>analysis<|message|>` and `<|end|>` tags" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage\n", "\n", "### Launching the Server" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specify the `--reasoning-parser` option when launching the server." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import requests\n", "from openai import OpenAI\n", "from sglang.test.doc_patch import launch_server_cmd\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", "\n", "server_process, port = launch_server_cmd(\n", "    \"python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1 --log-level warning\"\n", ")\n", "\n", "wait_for_server(f\"http://localhost:{port}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `--reasoning-parser` selects the parser used to interpret model responses." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### OpenAI Compatible API\n", "\n", "With the OpenAI-compatible API, the response contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:\n", "\n", "- `reasoning_content`: The content of the chain of thought (CoT).\n", "- `content`: The content of the final answer."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize an OpenAI-compatible client\n", "client = OpenAI(api_key=\"None\", base_url=f\"http://localhost:{port}/v1\")\n", "model_name = client.models.list().data[0].id\n", "\n", "messages = [\n", "    {\n", "        \"role\": \"user\",\n", "        \"content\": \"What is 1+3?\",\n", "    }\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Non-Streaming Request" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response_non_stream = client.chat.completions.create(\n", "    model=model_name,\n", "    messages=messages,\n", "    temperature=0.6,\n", "    top_p=0.95,\n", "    stream=False,  # Non-streaming\n", "    extra_body={\"separate_reasoning\": True},\n", ")\n", "print_highlight(\"==== Reasoning ====\")\n", "print_highlight(response_non_stream.choices[0].message.reasoning_content)\n", "\n", "print_highlight(\"==== Text ====\")\n", "print_highlight(response_non_stream.choices[0].message.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Streaming Request" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response_stream = client.chat.completions.create(\n", "    model=model_name,\n", "    messages=messages,\n", "    temperature=0.6,\n", "    top_p=0.95,\n", "    stream=True,  # Streaming\n", "    extra_body={\"separate_reasoning\": True},\n", ")\n", "\n", "reasoning_content = \"\"\n", "content = \"\"\n", "for chunk in response_stream:\n", "    # Each delta may carry final-answer text, reasoning text, or neither\n", "    if chunk.choices[0].delta.content:\n", "        content += chunk.choices[0].delta.content\n", "    if chunk.choices[0].delta.reasoning_content:\n", "        reasoning_content += chunk.choices[0].delta.reasoning_content\n", "\n", "print_highlight(\"==== Reasoning ====\")\n", "print_highlight(reasoning_content)\n", "\n", "print_highlight(\"==== Text ====\")\n", "print_highlight(content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, set `stream_reasoning` to `False` to buffer the reasoning content and deliver it all at once in the last reasoning chunk (or the first chunk after the reasoning content)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response_stream = client.chat.completions.create(\n", "    model=model_name,\n", "    messages=messages,\n", "    temperature=0.6,\n", "    top_p=0.95,\n", "    stream=True,  # Streaming\n", "    extra_body={\"separate_reasoning\": True, \"stream_reasoning\": False},\n", ")\n", "\n", "reasoning_content = \"\"\n", "content = \"\"\n", "for chunk in response_stream:\n", "    # With stream_reasoning=False, reasoning arrives in a single buffered chunk\n", "    if chunk.choices[0].delta.content:\n", "        content += chunk.choices[0].delta.content\n", "    if chunk.choices[0].delta.reasoning_content:\n", "        reasoning_content += chunk.choices[0].delta.reasoning_content\n", "\n", "print_highlight(\"==== Reasoning ====\")\n", "print_highlight(reasoning_content)\n", "\n", "print_highlight(\"==== Text ====\")\n", "print_highlight(content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reasoning separation is enabled by default when the server is launched with `--reasoning-parser`.\n", "\n", "**To disable it, set the `separate_reasoning` option to `False` in the request.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response_non_stream = client.chat.completions.create(\n", "    model=model_name,\n", "    messages=messages,\n", "    temperature=0.6,\n", "    top_p=0.95,\n", "    stream=False,  # Non-streaming\n", "    extra_body={\"separate_reasoning\": False},\n", ")\n", "\n", "print_highlight(\"==== Original Output ====\")\n", "print_highlight(response_non_stream.choices[0].message.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SGLang Native API\n", "\n", "With the native API, first call the `/generate` endpoint to get the raw model output, then send that text to the `/separate_reasoning` endpoint together with the parser name." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", "prompt = tokenizer.apply_chat_template(\n", "    messages,\n", "    tokenize=False,\n", "    add_generation_prompt=True,\n", ")\n", "\n", "gen_url = f\"http://localhost:{port}/generate\"\n", "gen_data = {\n", "    \"text\": prompt,\n", "    \"sampling_params\": {\n", "        \"skip_special_tokens\": False,\n", "        \"max_new_tokens\": 1024,\n", "        \"temperature\": 0.6,\n", "        \"top_p\": 0.95,\n", "    },\n", "}\n", "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", "\n", "print_highlight(\"==== Original Output ====\")\n", "print_highlight(gen_response)\n", "\n", "parse_url = f\"http://localhost:{port}/separate_reasoning\"\n", "separate_reasoning_data = {\n", "    \"text\": gen_response,\n", "    \"reasoning_parser\": \"deepseek-r1\",\n", "}\n", "separate_reasoning_response_json = requests.post(\n", "    parse_url, json=separate_reasoning_data\n", ").json()\n", "print_highlight(\"==== Reasoning ====\")\n", "print_highlight(separate_reasoning_response_json[\"reasoning_text\"])\n", "print_highlight(\"==== Text ====\")\n", "print_highlight(separate_reasoning_response_json[\"text\"])" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "terminate_process(server_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Offline Engine API" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sglang as sgl\n", "from sglang.srt.parser.reasoning_parser import ReasoningParser\n", "from sglang.utils import print_highlight\n", "from transformers import AutoTokenizer\n", "\n", "llm = sgl.Engine(model_path=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n", "prompt = tokenizer.apply_chat_template(\n", "    messages,\n", "    tokenize=False,\n", "    add_generation_prompt=True,\n", ")\n", "sampling_params = {\n", "    \"max_new_tokens\": 1024,\n", "    \"skip_special_tokens\": False,\n", "    \"temperature\": 0.6,\n", "    \"top_p\": 0.95,\n", "}\n", "result = llm.generate(prompt=prompt, sampling_params=sampling_params)\n", "\n", "generated_text = result[\"text\"]  # Single prompt, so `result` is a single dict\n", "\n", "print_highlight(\"==== Original Output ====\")\n", "print_highlight(generated_text)\n", "\n", "# Split the raw output offline with the same parser the server used above\n", "parser = ReasoningParser(\"deepseek-r1\")\n", "reasoning_text, text = parser.parse_non_stream(generated_text)\n", "print_highlight(\"==== Reasoning ====\")\n", "print_highlight(reasoning_text)\n", "print_highlight(\"==== Text ====\")\n", "print_highlight(text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm.shutdown()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Supporting New Reasoning Model Schemas\n", "\n", "To support a future reasoning model, implement its reasoning parser as a subclass of `BaseReasoningFormatDetector` in `python/sglang/srt/parser/reasoning_parser.py`, then specify that parser for the new model's reasoning schema."
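, "\n", "\n", "As a rough illustration of what such a detector does, the hypothetical cell below splits generated text into reasoning and normal content using start and end tags. `TagReasoningSplitter` and its `force_reasoning` flag are illustrative names chosen for this sketch, not part of SGLang's API; consult `BaseReasoningFormatDetector` for the real interface. `force_reasoning` mimics models such as DeepSeek-R1 that may omit the `<think>` start tag." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Hypothetical sketch, NOT SGLang's BaseReasoningFormatDetector API:\n",
"# it only illustrates splitting output into (reasoning_text, normal_text).\n",
"from dataclasses import dataclass\n",
"from typing import Tuple\n",
"\n",
"\n",
"@dataclass\n",
"class TagReasoningSplitter:\n",
"    think_start: str = \"<think>\"\n",
"    think_end: str = \"</think>\"\n",
"    # Assumed flag: treat output as reasoning even without a start tag,\n",
"    # as with DeepSeek-R1, which jumps directly to thinking content.\n",
"    force_reasoning: bool = False\n",
"\n",
"    def parse_non_stream(self, text: str) -> Tuple[str, str]:\n",
"        _, found_start, rest = text.partition(self.think_start)\n",
"        if not (found_start or self.force_reasoning):\n",
"            # No reasoning block: everything is normal content.\n",
"            return \"\", text\n",
"        body = rest if found_start else text\n",
"        reasoning, found_end, normal = body.partition(self.think_end)\n",
"        if not found_end:\n",
"            # Generation stopped before the end tag: all of it is reasoning.\n",
"            return body.strip(), \"\"\n",
"        return reasoning.strip(), normal.strip()\n",
"\n",
"\n",
"splitter = TagReasoningSplitter(force_reasoning=True)\n",
"print(splitter.parse_non_stream(\"Add 1 and 3.</think>The answer is 4.\"))\n",
"# -> ('Add 1 and 3.', 'The answer is 4.')\n",
"print(splitter.parse_non_stream(\"<think>Add 1 and 3.</think>The answer is 4.\"))\n",
"# -> ('Add 1 and 3.', 'The answer is 4.')\n",
"print(TagReasoningSplitter().parse_non_stream(\"The answer is 4.\"))\n",
"# -> ('', 'The answer is 4.')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Using `str.partition` splits only at the first occurrence of each tag. A streaming parser additionally needs to buffer incoming text, because a tag such as `</think>` can be split across two chunks."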
] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }