{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SGLang Native APIs\n",
"\n",
"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
"\n",
"- `/generate` (text generation model)\n",
"- `/get_model_info`\n",
"- `/get_server_info`\n",
"- `/health`\n",
"- `/health_generate`\n",
"- `/flush_cache`\n",
"- `/update_weights`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
"- `/start_expert_distribution_record`\n",
"- `/stop_expert_distribution_record`\n",
"- `/dump_expert_distribution_record`\n",
"\n",
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch A Server"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate (text generation model)\n",
"Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](./sampling_params.md)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())"
]
},
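{
"cell_type": "markdown",
"metadata": {},
"source": [
"The request body can also carry a `sampling_params` object. Below is a minimal sketch that fixes the temperature and caps the output length; the values chosen here are arbitrary for illustration, and the full set of supported fields is listed in the sampling parameters document linked above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of /generate with explicit sampling parameters.\n",
"url = f\"http://localhost:{port}/generate\"\n",
"data = {\n",
"    \"text\": \"What is the capital of France?\",\n",
"    \"sampling_params\": {\"temperature\": 0, \"max_new_tokens\": 32},\n",
"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())"
]
},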
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Model Info\n",
"\n",
"Get the information of the model.\n",
"\n",
"- `model_path`: The path/name of the model.\n",
"- `is_generation`: Whether the model is used as generation model or embedding model.\n",
"- `tokenizer_path`: The path/name of the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://localhost:{port}/get_model_info\"\n",
"\n",
"response = requests.get(url)\n",
"response_json = response.json()\n",
"print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json[\"is_generation\"] is True\n",
"assert response_json[\"tokenizer_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\", \"tokenizer_path\"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Server Info\n",
"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
"- Note: `get_server_info` merges the following deprecated endpoints:\n",
" - `get_server_args`\n",
" - `get_memory_pool_size` \n",
" - `get_max_total_num_tokens`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_server_info\n",
"\n",
"url = f\"http://localhost:{port}/get_server_info\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
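{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the response is a single JSON object, individual fields can be read directly. A small sketch, assuming the merged info includes a `max_total_num_tokens` field (the value formerly served by the deprecated `get_max_total_num_tokens` endpoint):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A small sketch: read one field from the merged server info.\n",
"# We assume the JSON includes \"max_total_num_tokens\"; .get() keeps\n",
"# the cell safe if the key is named differently.\n",
"server_info = requests.get(f\"http://localhost:{port}/get_server_info\").json()\n",
"print_highlight(f\"max_total_num_tokens: {server_info.get('max_total_num_tokens')}\")"
]
},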
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Health Check\n",
"- `/health`: Check the health of the server.\n",
"- `/health_generate`: Check the health of the server by generating one token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://localhost:{port}/health_generate\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://localhost:{port}/health\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
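{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common pattern is to poll `/health` until the server answers with status code 200, for example before routing traffic to it. A minimal polling sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# A minimal readiness-polling sketch built on /health: retry until\n",
"# the server answers 200 or the attempts run out.\n",
"for _ in range(10):\n",
"    try:\n",
"        if requests.get(f\"http://localhost:{port}/health\", timeout=5).status_code == 200:\n",
"            print_highlight(\"Server is healthy\")\n",
"            break\n",
"    except requests.RequestException:\n",
"        pass  # server not reachable yet\n",
"    time.sleep(1)"
]
},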
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Flush Cache\n",
"\n",
"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# flush cache\n",
"\n",
"url = f\"http://localhost:{port}/flush_cache\"\n",
"\n",
"response = requests.post(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Update Weights From Disk\n",
"\n",
"Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.\n",
"\n",
"SGLang support `update_weights_from_disk` API for continuous evaluation during training (save checkpoint to disk and update weights from disk).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# successful update with same architecture and size\n",
"\n",
"url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.text)\n",
"assert response.json()[\"success\"] is True\n",
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\""
]
},
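{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, we can query `/get_model_info` again; assuming the server updates its reported path after a successful weight update, `model_path` should now show the new model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch: we assume the reported model_path reflects\n",
"# the most recent successful weight update.\n",
"url = f\"http://localhost:{port}/get_model_info\"\n",
"response_json = requests.get(url).json()\n",
"print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B\""
]
},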
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# failed update with different parameter size or wrong name\n",
"\n",
"url = f\"http://localhost:{port}/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B-wrong\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"response_json = response.json()\n",
"print_highlight(response_json)\n",
"assert response_json[\"success\"] is False\n",
"assert response_json[\"message\"] == (\n",
" \"Failed to get weights iterator: \"\n",
" \"meta-llama/Llama-3.2-1B-wrong\"\n",
" \" (repository not found).\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Encode (embedding model)\n",
"\n",
"Encode text into embeddings. Note that this API is only available for [embedding models](openai_api_embeddings.html#openai-apis-embedding) and will raise an error for generation models.\n",
"Therefore, we launch a new server to server an embedding model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)\n",
"\n",
"embedding_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n",
" --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# successful encode for embedding model\n",
"\n",
"url = f\"http://localhost:{port}/encode\"\n",
"data = {\"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\", \"text\": \"Once upon a time\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"response_json = response.json()\n",
"print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")"
]
},
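{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged batching sketch: we assume the `text` field also accepts a list of strings, with the endpoint returning one embedding object per input (the single-text form above is the documented one):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Batch-encode sketch, assuming \"text\" accepts a list and the\n",
"# response is a list with one item per input text.\n",
"url = f\"http://localhost:{port}/encode\"\n",
"data = {\n",
"    \"model\": \"Alibaba-NLP/gte-Qwen2-7B-instruct\",\n",
"    \"text\": [\"Once upon a time\", \"The capital of France is\"],\n",
"}\n",
"\n",
"response_json = requests.post(url, json=data).json()\n",
"for i, item in enumerate(response_json):\n",
"    print_highlight(f\"Embedding {i} (first 5 dims): {item['embedding'][:5]}\")"
]
},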
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(embedding_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classify (reward model)\n",
"\n",
"SGLang Runtime also supports reward models. Here we use a reward model to classify the quality of pairwise generations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(embedding_process)\n",
"\n",
"# Note that SGLang now treats embedding models and reward models as the same type of models.\n",
"# This will be updated in the future.\n",
"\n",
"reward_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"PROMPT = (\n",
" \"What is the range of the numeric output of a sigmoid node in a neural network?\"\n",
")\n",
"\n",
"RESPONSE1 = \"The output of a sigmoid node is bounded between -1 and 1.\"\n",
"RESPONSE2 = \"The output of a sigmoid node is bounded between 0 and 1.\"\n",
"\n",
"CONVS = [\n",
" [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE1}],\n",
" [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE2}],\n",
"]\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n",
"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
"\n",
"url = f\"http://localhost:{port}/classify\"\n",
"data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
"\n",
"responses = requests.post(url, json=data).json()\n",
"for response in responses:\n",
" print_highlight(f\"reward: {response['embedding'][0]}\")"
]
},
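{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since each conversation maps to a single scalar reward, the two responses can be compared directly. The snippet below reuses the `responses` list from the previous cell; RESPONSE2 is the factually correct answer, so it should receive the higher reward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the two rewards from the previous cell; the higher-scored\n",
"# response is the one the reward model prefers.\n",
"rewards = [r[\"embedding\"][0] for r in responses]\n",
"best = rewards.index(max(rewards))\n",
"print_highlight(f\"Preferred response: RESPONSE{best + 1} (reward {rewards[best]:.4f})\")"
]
},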
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(reward_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Capture expert selection distribution in MoE models\n",
"\n",
"SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"expert_record_server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"import glob\n",
"\n",
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
"with open(output_file, \"r\") as f:\n",
" print_highlight(\"Content of dumped record:\")\n",
" for line in f:\n",
" print_highlight(line.strip())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(expert_record_server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Skip Tokenizer and Detokenizer\n",
"\n",
"SGLang Runtime also supports skip tokenizer and detokenizer. This is useful in cases like integrating with RLHF workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer_free_server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --skip-tokenizer-init\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
"\n",
"input_text = \"What is the capital of France?\"\n",
"\n",
"input_tokens = tokenizer.encode(input_text)\n",
"print_highlight(f\"Input Text: {input_text}\")\n",
"print_highlight(f\"Tokenized Input: {input_tokens}\")\n",
"\n",
"response = requests.post(\n",
" f\"http://localhost:{port}/generate\",\n",
" json={\n",
" \"input_ids\": input_tokens,\n",
" \"sampling_params\": {\n",
" \"temperature\": 0,\n",
" \"max_new_tokens\": 256,\n",
" \"stop_token_ids\": [tokenizer.eos_token_id],\n",
" },\n",
" \"stream\": False,\n",
" },\n",
")\n",
"output = response.json()\n",
"output_tokens = output[\"output_ids\"]\n",
"\n",
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
"print_highlight(f\"Decoded Output: {output_text}\")\n",
"print_highlight(f\"Output Text: {output['meta_info']['finish_reason']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(tokenizer_free_server_process)"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}