sglang_v0.5.2/sglang/docs/advanced_features/vlm_query.ipynb

{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Query Vision Language Model"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"## Querying Qwen-VL"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
"chat_template = \"qwen2-vl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.parser.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"### Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"from sglang import Engine\n",
"\n",
"llm = Engine(\n",
" model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
"print(out[\"text\"])"
]
},
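{
"cell_type": "markdown",
"id": "6a",
"metadata": {},
"source": [
"The same call also accepts a `sampling_params` dict (for example `temperature` and `max_new_tokens`); the values below are illustrative rather than tuned recommendations."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b",
"metadata": {},
"outputs": [],
"source": [
"# Generate again, this time with explicit (illustrative) sampling parameters.\n",
"out = llm.generate(\n",
"    prompt=conv.get_prompt(),\n",
"    image_data=[image],\n",
"    sampling_params={\"temperature\": 0.0, \"max_new_tokens\": 128},\n",
")\n",
"print(out[\"text\"])"
]
},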
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"### Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"# Compute the image embeddings using Huggingface.\n",
"\n",
"from transformers import AutoProcessor\n",
"from transformers import Qwen2_5_VLForConditionalGeneration\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
"vision = (\n",
" Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
")\n",
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"precomputed_embeddings = vision(\n",
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
")\n",
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_embeddings=precomputed_embeddings,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
"print(out[\"text\"])"
]
},
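{
"cell_type": "markdown",
"id": "9a",
"metadata": {},
"source": [
"Before moving on to the Llama 4 section below, you can shut down the Qwen engine to free its GPU memory; this sketch assumes `Engine.shutdown()` is available in your SGLang version."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b",
"metadata": {},
"outputs": [],
"source": [
"# Release the Qwen engine's GPU memory before loading the next model.\n",
"llm.shutdown()"
]
},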
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"## Querying Llama 4 (Vision)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
"chat_template = \"llama-4\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.parser.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"print(f\"Image size: {image.size}\")\n",
"\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "13",
"metadata": {},
"source": [
"### Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if not is_in_ci():\n",
" from sglang import Engine\n",
"\n",
" llm = Engine(\n",
" model_path=model_path,\n",
" trust_remote_code=True,\n",
" enable_multimodal=True,\n",
" mem_fraction_static=0.8,\n",
" tp_size=4,\n",
" attention_backend=\"fa3\",\n",
" context_length=65536,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
" print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "16",
"metadata": {},
"source": [
"### Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" # Compute the image embeddings using Huggingface.\n",
"\n",
" from transformers import AutoProcessor\n",
" from transformers import Llama4ForConditionalGeneration\n",
"\n",
" processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
" model = Llama4ForConditionalGeneration.from_pretrained(\n",
" model_path, torch_dtype=\"auto\"\n",
" ).eval()\n",
" vision = model.vision_model.cuda()\n",
" multi_modal_projector = model.multi_modal_projector.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
" processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
" )\n",
" print(f'{processed_prompt[\"pixel_values\"].shape=}')\n",
" input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"\n",
" image_outputs = vision(\n",
" processed_prompt[\"pixel_values\"].to(\"cuda\"), output_hidden_states=False\n",
" )\n",
" image_features = image_outputs.last_hidden_state\n",
" vision_flat = image_features.view(-1, image_features.size(-1))\n",
" precomputed_embeddings = multi_modal_projector(vision_flat)\n",
"\n",
" mm_item = dict(modality=\"IMAGE\", precomputed_embeddings=precomputed_embeddings)\n",
" out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
" print(out[\"text\"])"
]
}
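,
{
"cell_type": "markdown",
"id": "18a",
"metadata": {},
"source": [
"Finally, shut down the Llama 4 engine to release its GPU memory (again, this assumes `Engine.shutdown()` is available in your SGLang version)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18b",
"metadata": {},
"outputs": [],
"source": [
"if not is_in_ci():\n",
"    llm.shutdown()"
]
}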
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"custom_cell_magics": "kql",
"encoding": "# -*- coding: utf-8 -*-"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}