{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reasoning Parser\n",
"\n",
"SGLang supports separating reasoning content from \"normal\" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).\n",
"\n",
"## Supported Models & Parsers\n",
"\n",
"| Model | Reasoning tags | Parser |\n",
"|---------|-----------------------------|------------------|\n",
"| [DeepSeek-R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `<think>` … `</think>` | `deepseek-r1` |\n",
"| [Qwen3 and QwQ series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `<think>` … `</think>` | `qwen3` |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"### Launching the Server"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Specify the `--reasoning-parser` option."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from openai import OpenAI\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
"    from patch import launch_server_cmd\n",
"else:\n",
"    from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
"    \"python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
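{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `--reasoning-parser` defines the parser used to interpret responses; pick the value that matches your model from the Supported Models & Parsers table (here, `deepseek-r1`)."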
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### OpenAI Compatible API\n",
"\n",
"Using the OpenAI compatible API, the contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:\n",
"\n",
"- `reasoning_content`: The content of the chain of thought (CoT).\n",
"- `content`: The content of the final answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=f\"http://0.0.0.0:{port}/v1\")\n",
"model_name = client.models.list().data[0].id\n",
"\n",
"messages = [\n",
"    {\n",
"        \"role\": \"user\",\n",
"        \"content\": \"What is 1+3?\",\n",
"    }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_non_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.6,\n",
"    top_p=0.95,\n",
"    stream=False,  # Non-streaming\n",
"    extra_body={\"separate_reasoning\": True},\n",
")\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(response_non_stream.choices[0].message.reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(response_non_stream.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.6,\n",
"    top_p=0.95,\n",
"    stream=True,  # Streaming\n",
"    extra_body={\"separate_reasoning\": True},\n",
")\n",
"\n",
"# Accumulate the reasoning and the final answer from the streamed deltas\n",
"reasoning_content = \"\"\n",
"content = \"\"\n",
"for chunk in response_stream:\n",
"    if chunk.choices[0].delta.content:\n",
"        content += chunk.choices[0].delta.content\n",
"    if chunk.choices[0].delta.reasoning_content:\n",
"        reasoning_content += chunk.choices[0].delta.reasoning_content\n",
"\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, set `stream_reasoning` to `False` to buffer the reasoning content and receive it all at once, in the last reasoning chunk (or the first chunk after the reasoning content)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.6,\n",
"    top_p=0.95,\n",
"    stream=True,  # Streaming\n",
"    extra_body={\"separate_reasoning\": True, \"stream_reasoning\": False},\n",
")\n",
"\n",
"reasoning_content = \"\"\n",
"content = \"\"\n",
"for chunk in response_stream:\n",
"    if chunk.choices[0].delta.content:\n",
"        content += chunk.choices[0].delta.content\n",
"    if chunk.choices[0].delta.reasoning_content:\n",
"        # Buffered reasoning arrives in a single chunk, so assign instead of appending\n",
"        reasoning_content = chunk.choices[0].delta.reasoning_content\n",
"\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_content)\n",
"\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reasoning separation is enabled by default when `--reasoning-parser` is specified.  \n",
"**To disable it, set the `separate_reasoning` option to `False` in the request.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response_non_stream = client.chat.completions.create(\n",
"    model=model_name,\n",
"    messages=messages,\n",
"    temperature=0.6,\n",
"    top_p=0.95,\n",
"    stream=False,  # Non-streaming\n",
"    extra_body={\"separate_reasoning\": False},\n",
")\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(response_non_stream.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SGLang Native API "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"input = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")\n",
"\n",
"gen_url = f\"http://localhost:{port}/generate\"\n",
"gen_data = {\n",
"    \"text\": input,\n",
"    \"sampling_params\": {\n",
"        \"skip_special_tokens\": False,\n",
"        \"max_new_tokens\": 1024,\n",
"        \"temperature\": 0.6,\n",
"        \"top_p\": 0.95,\n",
"    },\n",
"}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(gen_response)\n",
"\n",
"parse_url = f\"http://localhost:{port}/separate_reasoning\"\n",
"separate_reasoning_data = {\n",
"    \"text\": gen_response,\n",
"    \"reasoning_parser\": \"deepseek-r1\",\n",
"}\n",
"separate_reasoning_response_json = requests.post(\n",
"    parse_url, json=separate_reasoning_data\n",
").json()\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(separate_reasoning_response_json[\"reasoning_text\"])\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(separate_reasoning_response_json[\"text\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.reasoning_parser import ReasoningParser\n",
"from sglang.utils import print_highlight\n",
"\n",
"llm = sgl.Engine(model_path=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\")\n",
"input = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")\n",
"sampling_params = {\n",
"    \"max_new_tokens\": 1024,\n",
"    \"skip_special_tokens\": False,\n",
"    \"temperature\": 0.6,\n",
"    \"top_p\": 0.95,\n",
"}\n",
"result = llm.generate(prompt=input, sampling_params=sampling_params)\n",
"\n",
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
"\n",
"print_highlight(\"==== Original Output ====\")\n",
"print_highlight(generated_text)\n",
"\n",
"parser = ReasoningParser(\"deepseek-r1\")\n",
"reasoning_text, text = parser.parse_non_stream(generated_text)\n",
"print_highlight(\"==== Reasoning ====\")\n",
"print_highlight(reasoning_text)\n",
"print_highlight(\"==== Text ====\")\n",
"print_highlight(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Supporting New Reasoning Model Schemas\n",
"\n",
"For future reasoning models, you can implement the reasoning parser as a subclass of `BaseReasoningFormatDetector` in `python/sglang/srt/reasoning_parser.py` and specify the reasoning parser for new reasoning model schemas accordingly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"class DeepSeekR1Detector(BaseReasoningFormatDetector):\n",
"    \"\"\"\n",
"    Detector for DeepSeek-R1 model.\n",
"    Assumes reasoning format:\n",
"      (<think>)*(.*)</think>\n",
"    Returns all the text before the </think> tag as `reasoning_text`\n",
"    and the rest of the text as `normal_text`.\n",
"\n",
"    Args:\n",
"        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.\n",
"            If True, streams reasoning content as it arrives.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, stream_reasoning: bool = False):\n",
"        # DeepSeek-R1 is assumed to be reasoning until `</think>` token\n",
"        super().__init__(\"<think>\", \"</think>\", True, stream_reasoning=stream_reasoning)\n",
"        # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599\n",
"\n",
"\n",
"class ReasoningParser:\n",
"    \"\"\"\n",
"    Parser that handles both streaming and non-streaming scenarios for extracting\n",
"    reasoning content from model outputs.\n",
"\n",
"    Args:\n",
"        model_type (str): Type of model to parse reasoning from\n",
"        stream_reasoning (bool): If False, accumulates reasoning content until complete.\n",
"            If True, streams reasoning content as it arrives.\n",
"    \"\"\"\n",
"\n",
"    # Maps the parser name to the detector class (not an instance)\n",
"    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {\n",
"        \"deepseek-r1\": DeepSeekR1Detector\n",
"    }\n",
"\n",
"    def __init__(self, model_type: str = None, stream_reasoning: bool = True):\n",
"        if not model_type:\n",
"            raise ValueError(\"Model type must be specified\")\n",
"\n",
"        detector_class = self.DetectorMap.get(model_type.lower())\n",
"        if not detector_class:\n",
"            raise ValueError(f\"Unsupported model type: {model_type}\")\n",
"\n",
"        self.detector = detector_class(stream_reasoning=stream_reasoning)\n",
"\n",
"    def parse_non_stream(self, full_text: str) -> Tuple[str, str]:\n",
"        \"\"\"Non-streaming call: one-time parsing\"\"\"\n",
"        ret = self.detector.detect_and_parse(full_text)\n",
"        return ret.reasoning_text, ret.normal_text\n",
"\n",
"    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:\n",
"        \"\"\"Streaming call: incremental parsing\"\"\"\n",
"        ret = self.detector.parse_streaming_increment(chunk_text)\n",
"        return ret.reasoning_text, ret.normal_text\n",
"```"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}