326 lines
8.1 KiB
Plaintext
326 lines
8.1 KiB
Plaintext
{
|
|
"cells": [
|
|
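{
|

"cell_type": "markdown",
|

"id": "0",
|

"metadata": {},
|

"source": [
|

"# Querying Vision Language Models\n",
|

"\n",
|

"This notebook shows how to query vision language models (Qwen2.5-VL and Llama 4) through SGLang's offline `Engine` API, first with a raw image and then with image embeddings precomputed by Hugging Face Transformers."
|

]
|

},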
|
|
{
|

"cell_type": "markdown",
|

"id": "1",
|

"metadata": {},
|

"source": [
|

"## Querying Qwen2.5-VL"
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "2",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Jupyter's kernel already runs an asyncio event loop; nest_asyncio patches\n",
|

"# asyncio so that nested event loops work. Run this cell first.\n",
|

"import nest_asyncio\n",
|

"\n",
|

"nest_asyncio.apply()\n",
|

"\n",
|

"# Checkpoint to load and the SGLang chat template that matches it.\n",
|

"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
|

"chat_template = \"qwen2-vl\""
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "3",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Let's create a prompt.\n",
|

"\n",
|

"from io import BytesIO\n",
|

"\n",
|

"import requests\n",
|

"from PIL import Image\n",
|

"\n",
|

"from sglang.srt.parser.conversation import chat_templates\n",
|

"\n",
|

"# Download the example image. The timeout keeps a stalled connection from\n",
|

"# hanging the notebook, and raise_for_status() reports HTTP errors directly\n",
|

"# instead of letting PIL fail on an error page.\n",
|

"response = requests.get(\n",
|

"    \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\",\n",
|

"    timeout=30,\n",
|

")\n",
|

"response.raise_for_status()\n",
|

"image = Image.open(BytesIO(response.content))\n",
|

"\n",
|

"conv = chat_templates[chat_template].copy()\n",
|

"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
|

"conv.append_message(conv.roles[1], \"\")\n",
|

"conv.image_data = [image]\n",
|

"\n",
|

"print(conv.get_prompt())\n",
|

"image"
|

]
|

},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Query via the offline Engine API"
|
|
]
|
|
},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "5",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"from sglang import Engine\n",
|

"\n",
|

"# Launch the offline engine for the checkpoint and chat template chosen above.\n",
|

"llm = Engine(\n",
|

"    model_path=model_path,\n",
|

"    chat_template=chat_template,\n",
|

"    mem_fraction_static=0.8,\n",
|

")"
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "6",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Send the templated prompt together with the raw PIL image.\n",
|

"out = llm.generate(\n",
|

"    image_data=[image],\n",
|

"    prompt=conv.get_prompt(),\n",
|

")\n",
|

"print(out[\"text\"])"
|

]
|

},
|
|
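{
|

"cell_type": "markdown",
|

"id": "7",
|

"metadata": {},
|

"source": [
|

"### Query via the offline Engine API with precomputed image embeddings\n",
|

"\n",
|

"Here the vision tower runs in Hugging Face Transformers, and its output embeddings are passed to the engine together with the tokenized prompt instead of the raw image."
|

]
|

},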
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "8",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Compute the image embeddings using Hugging Face Transformers.\n",
|

"from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration\n",
|

"\n",
|

"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
|

"\n",
|

"# Keep only the vision tower (`.visual`), moved to the GPU; no reference to the\n",
|

"# rest of the model is retained.\n",
|

"vision = (\n",
|

"    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path)\n",
|

"    .eval()\n",
|

"    .visual.cuda()\n",
|

")"
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "9",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"import torch\n",
|

"\n",
|

"processed_prompt = processor(\n",
|

"    images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
|

")\n",
|

"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
|

"\n",
|

"# Inference only: without no_grad() the forward pass records an autograd graph\n",
|

"# and keeps its intermediate activations alive on the GPU.\n",
|

"with torch.no_grad():\n",
|

"    precomputed_embeddings = vision(\n",
|

"        processed_prompt[\"pixel_values\"].cuda(),\n",
|

"        processed_prompt[\"image_grid_thw\"].cuda(),\n",
|

"    )\n",
|

"\n",
|

"mm_item = dict(\n",
|

"    modality=\"IMAGE\",\n",
|

"    image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
|

"    precomputed_embeddings=precomputed_embeddings,\n",
|

")\n",
|

"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
|

"print(out[\"text\"])\n",
|

"\n",
|

"# Last Qwen cell: release the engine and the HF vision tower so their GPU\n",
|

"# memory is available again; the Llama 4 engine below also requests\n",
|

"# mem_fraction_static=0.8.\n",
|

"llm.shutdown()\n",
|

"del vision\n",
|

"torch.cuda.empty_cache()"
|

]
|

},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "10",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Querying Llama 4 (Vision)"
|
|
]
|
|
},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "11",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Jupyter's kernel already runs an asyncio event loop; nest_asyncio patches\n",
|

"# asyncio so that nested event loops work. Run this cell first.\n",
|

"import nest_asyncio\n",
|

"\n",
|

"nest_asyncio.apply()\n",
|

"\n",
|

"# Checkpoint to load and the SGLang chat template that matches it.\n",
|

"model_path = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
|

"chat_template = \"llama-4\""
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "12",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"# Let's create a prompt.\n",
|

"\n",
|

"from io import BytesIO\n",
|

"\n",
|

"import requests\n",
|

"from PIL import Image\n",
|

"\n",
|

"from sglang.srt.parser.conversation import chat_templates\n",
|

"\n",
|

"# Download the example image. The timeout keeps a stalled connection from\n",
|

"# hanging the notebook, and raise_for_status() reports HTTP errors directly\n",
|

"# instead of letting PIL fail on an error page.\n",
|

"response = requests.get(\n",
|

"    \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\",\n",
|

"    timeout=30,\n",
|

")\n",
|

"response.raise_for_status()\n",
|

"image = Image.open(BytesIO(response.content))\n",
|

"\n",
|

"conv = chat_templates[chat_template].copy()\n",
|

"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
|

"conv.append_message(conv.roles[1], \"\")\n",
|

"conv.image_data = [image]\n",
|

"\n",
|

"print(conv.get_prompt())\n",
|

"print(f\"Image size: {image.size}\")\n",
|

"\n",
|

"image"
|

]
|

},
|
|
{
|

"cell_type": "markdown",
|

"id": "13",
|

"metadata": {},
|

"source": [
|

"### Query via the offline Engine API\n",
|

"\n",
|

"The Llama 4 cells below use 4-way tensor parallelism (`tp_size=4`) and are skipped in CI."
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "14",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"from sglang.test.test_utils import is_in_ci\n",
|

"\n",
|

"# Like the rest of this section, the Llama 4 engine is not launched in CI.\n",
|

"if not is_in_ci():\n",
|

"    from sglang import Engine\n",
|

"\n",
|

"    llm = Engine(\n",
|

"        # Model and context settings.\n",
|

"        model_path=model_path,\n",
|

"        context_length=65536,\n",
|

"        trust_remote_code=True,\n",
|

"        enable_multimodal=True,\n",
|

"        # Resource settings: 4-way tensor parallelism, FA3 attention backend.\n",
|

"        tp_size=4,\n",
|

"        attention_backend=\"fa3\",\n",
|

"        mem_fraction_static=0.8,\n",
|

"    )"
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "15",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"if not is_in_ci():\n",
|

"    # Send the templated prompt together with the raw PIL image.\n",
|

"    out = llm.generate(\n",
|

"        image_data=[image],\n",
|

"        prompt=conv.get_prompt(),\n",
|

"    )\n",
|

"    print(out[\"text\"])"
|

]
|

},
|
|
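{
|

"cell_type": "markdown",
|

"id": "16",
|

"metadata": {},
|

"source": [
|

"### Query via the offline Engine API with precomputed image embeddings\n",
|

"\n",
|

"Loading the Hugging Face checkpoint requires enough host memory for the full model, even though only its vision tower and multimodal projector are used to compute the embeddings."
|

]
|

},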
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "17",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"if not is_in_ci():\n",
|

"    # Compute the image embeddings using Hugging Face Transformers.\n",
|

"    from transformers import AutoProcessor\n",
|

"    from transformers import Llama4ForConditionalGeneration\n",
|

"\n",
|

"    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
|

"    model = Llama4ForConditionalGeneration.from_pretrained(\n",
|

"        model_path, torch_dtype=\"auto\"\n",
|

"    ).eval()\n",
|

"    # Only the vision tower and the multimodal projector are needed here.\n",
|

"    vision = model.vision_model.cuda()\n",
|

"    multi_modal_projector = model.multi_modal_projector.cuda()\n",
|

"    # Drop the last reference to the full model so the unused language-model\n",
|

"    # weights (never moved off the host) can be freed.\n",
|

"    del model"
|

]
|

},
|
|
{
|

"cell_type": "code",
|

"execution_count": null,
|

"id": "18",
|

"metadata": {},
|

"outputs": [],
|

"source": [
|

"if not is_in_ci():\n",
|

"    import torch\n",
|

"\n",
|

"    processed_prompt = processor(\n",
|

"        images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
|

"    )\n",
|

"    print(f'{processed_prompt[\"pixel_values\"].shape=}')\n",
|

"    input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
|

"\n",
|

"    # The model was loaded with torch_dtype=\"auto\", so the vision weights may not\n",
|

"    # be float32; cast the pixels to the same dtype (a no-op if they already match).\n",
|

"    vision_dtype = next(vision.parameters()).dtype\n",
|

"    # Inference only: no_grad() keeps the forward passes from recording an autograd graph.\n",
|

"    with torch.no_grad():\n",
|

"        image_outputs = vision(\n",
|

"            processed_prompt[\"pixel_values\"].to(\"cuda\", dtype=vision_dtype),\n",
|

"            output_hidden_states=False,\n",
|

"        )\n",
|

"        image_features = image_outputs.last_hidden_state\n",
|

"        vision_flat = image_features.view(-1, image_features.size(-1))\n",
|

"        precomputed_embeddings = multi_modal_projector(vision_flat)\n",
|

"\n",
|

"    mm_item = dict(modality=\"IMAGE\", precomputed_embeddings=precomputed_embeddings)\n",
|

"    out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
|

"    print(out[\"text\"])\n",
|

"\n",
|

"    # Final cell: release the engine's resources.\n",
|

"    llm.shutdown()"
|

]
|

}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"custom_cell_magics": "kql",
|
|
"encoding": "# -*- coding: utf-8 -*-"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|