{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Query Vision Language Model\n",
    "\n",
    "This notebook queries two vision language models, Qwen2.5-VL and Llama 4, through SGLang's offline `Engine` API. Each section first sends a raw image together with a text prompt, and then sends image embeddings that were precomputed with Hugging Face Transformers."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## Querying Qwen-VL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()  # Run this first.\n",
    "\n",
    "model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
    "chat_template = \"qwen2-vl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's create a prompt.\n",
    "\n",
    "from io import BytesIO\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "from sglang.srt.parser.conversation import chat_templates\n",
    "\n",
    "image_url = \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
    "\n",
    "# Bound the download time and fail on HTTP errors, so that PIL is never\n",
    "# handed an error page instead of image bytes.\n",
    "response = requests.get(image_url, timeout=30)\n",
    "response.raise_for_status()\n",
    "image = Image.open(BytesIO(response.content))\n",
    "\n",
    "conv = chat_templates[chat_template].copy()\n",
    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
    "conv.append_message(conv.roles[1], \"\")\n",
    "conv.image_data = [image]\n",
    "\n",
    "print(conv.get_prompt())\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sglang import Engine\n",
    "\n",
    "llm = Engine(\n",
    "    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
    "print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API, but send precomputed embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the image embeddings using Hugging Face.\n",
    "\n",
    "import torch\n",
    "from transformers import AutoProcessor\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
    "vision = (\n",
    "    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_prompt = processor(\n",
    "    images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
    ")\n",
    "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
    "\n",
    "# Inference only: disable autograd so the embeddings carry no graph.\n",
    "with torch.no_grad():\n",
    "    precomputed_embeddings = vision(\n",
    "        processed_prompt[\"pixel_values\"].cuda(),\n",
    "        processed_prompt[\"image_grid_thw\"].cuda(),\n",
    "    )\n",
    "\n",
    "mm_item = dict(\n",
    "    modality=\"IMAGE\",\n",
    "    image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
    "    precomputed_embeddings=precomputed_embeddings,\n",
    ")\n",
    "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
    "print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shut down the Qwen engine so that it no longer holds GPU memory when the\n",
    "# next section launches a second engine.\n",
    "llm.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## Querying Llama 4 (Vision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()  # Run this first.\n",
    "\n",
    "model_path = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
    "chat_template = \"llama-4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's create a prompt.\n",
    "\n",
    "from io import BytesIO\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "from sglang.srt.parser.conversation import chat_templates\n",
    "\n",
    "image_url = \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
    "\n",
    "# Bound the download time and fail on HTTP errors, so that PIL is never\n",
    "# handed an error page instead of image bytes.\n",
    "response = requests.get(image_url, timeout=30)\n",
    "response.raise_for_status()\n",
    "image = Image.open(BytesIO(response.content))\n",
    "\n",
    "conv = chat_templates[chat_template].copy()\n",
    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
    "conv.append_message(conv.roles[1], \"\")\n",
    "conv.image_data = [image]\n",
    "\n",
    "print(conv.get_prompt())\n",
    "print(f\"Image size: {image.size}\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sglang.test.test_utils import is_in_ci\n",
    "\n",
    "# The Llama 4 cells are skipped when is_in_ci() reports a CI run.\n",
    "if not is_in_ci():\n",
    "    from sglang import Engine\n",
    "\n",
    "    llm = Engine(\n",
    "        model_path=model_path,\n",
    "        trust_remote_code=True,\n",
    "        enable_multimodal=True,\n",
    "        mem_fraction_static=0.8,\n",
    "        tp_size=4,\n",
    "        attention_backend=\"fa3\",\n",
    "        context_length=65536,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
    "    print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API, but send precomputed embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    # Compute the image embeddings using Hugging Face.\n",
    "\n",
    "    import torch\n",
    "    from transformers import AutoProcessor\n",
    "    from transformers import Llama4ForConditionalGeneration\n",
    "\n",
    "    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
    "    model = Llama4ForConditionalGeneration.from_pretrained(\n",
    "        model_path, torch_dtype=\"auto\"\n",
    "    ).eval()\n",
    "    vision = model.vision_model.cuda()\n",
    "    multi_modal_projector = model.multi_modal_projector.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    processed_prompt = processor(\n",
    "        images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
    "    )\n",
    "    print(f'{processed_prompt[\"pixel_values\"].shape=}')\n",
    "    input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
    "\n",
    "    # Inference only: disable autograd so the embeddings carry no graph.\n",
    "    with torch.no_grad():\n",
    "        image_outputs = vision(\n",
    "            processed_prompt[\"pixel_values\"].to(\"cuda\"), output_hidden_states=False\n",
    "        )\n",
    "        image_features = image_outputs.last_hidden_state\n",
    "        # Collapse all leading dimensions so the projector sees one row per feature vector.\n",
    "        vision_flat = image_features.view(-1, image_features.size(-1))\n",
    "        precomputed_embeddings = multi_modal_projector(vision_flat)\n",
    "\n",
    "    mm_item = dict(modality=\"IMAGE\", precomputed_embeddings=precomputed_embeddings)\n",
    "    out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
    "    print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    # Shut down the Llama 4 engine now that the notebook is finished with it.\n",
    "    llm.shutdown()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "custom_cell_magics": "kql",
   "encoding": "# -*- coding: utf-8 -*-"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}