first commit

Commit 466c38a2d4

@@ -0,0 +1 @@
sgl-kernel/3rdparty/tensorrt_llm/*
@@ -0,0 +1,35 @@
FROM lmsysorg/sglang:dev

# Create non-root user with specified UID and GID
# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.
ARG HOST_UID=1003
ARG HOST_GID=1003
RUN groupadd -g $HOST_GID devuser && \
    useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser

# Give devuser sudo access
RUN apt-get update && apt-get install -y sudo && \
    echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

# Set up oh-my-zsh for devuser
RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \
    cp /root/.zshrc /home/devuser/.zshrc && \
    cp /root/.vimrc /home/devuser/.vimrc && \
    cp /root/.tmux.conf /home/devuser/.tmux.conf && \
    sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \
    chown -R devuser:devuser /home/devuser/

# Set workspace directory and ownership
WORKDIR /sgl-workspace/sglang
RUN chown -R devuser:devuser /sgl-workspace

# Switch to devuser
USER devuser

# Install uv
RUN curl -LsSf https://astral.sh/uv/install.sh | sh

# Install rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
@@ -0,0 +1,24 @@
{
    "name": "sglang",
    "build": {
        "dockerfile": "Dockerfile"
    },
    "remoteUser": "devuser",
    "customizations": {
        "vscode": {
            "extensions": [
                // Python development
                "ms-python.python",
                "charliermarsh.ruff",
                // Rust development
                "rust-lang.rust-analyzer",
                "tamasfe.even-better-toml"
            ]
        }
    },
    "forwardPorts": [],
    "runArgs": [
        "--gpus",
        "all"
    ]
}
@@ -0,0 +1,25 @@
# https://editorconfig.org/

root = true

[*]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true

[*.{json,yaml,yml}]
indent_size = 2

[*.md]
indent_size = 2
x-soft-wrap-text = true

[*.rst]
indent_size = 4
x-soft-wrap-text = true

[Makefile]
indent_style = tab
@@ -0,0 +1,229 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# MacOS
.DS_Store

# Vim
*.swp

# Documentation
docs/_build

# SGL
benchmark/mmlu/data
benchmark/mmlu/data.tar
benchmark/llava_bench/images
benchmark/llava_bench/mme_pack
*.jsonl
tmp*.txt

# Plots
*.png
*.pdf

# personal
work_dirs/
*.csv

!logo.png

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

compile_commands.json

*.iml

# VSCode
.vscode

1
@@ -0,0 +1,12 @@
[submodule "sgl-kernel/3rdparty/cutlass"]
	path = sgl-kernel/3rdparty/cutlass
	url = https://github.com/NVIDIA/cutlass.git
[submodule "sgl-kernel/3rdparty/cccl"]
	path = sgl-kernel/3rdparty/cccl
	url = https://github.com/NVIDIA/cccl.git
[submodule "sgl-kernel/3rdparty/flashinfer"]
	path = sgl-kernel/3rdparty/flashinfer
	url = https://github.com/flashinfer-ai/flashinfer.git
[submodule "sgl-kernel/3rdparty/deepgemm"]
	path = sgl-kernel/3rdparty/deepgemm
	url = https://github.com/deepseek-ai/DeepGEMM
@@ -0,0 +1,3 @@
[settings]
profile=black
known_first_party=sglang
@@ -0,0 +1,48 @@
default_stages: [pre-commit, pre-push, manual]

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-symlinks
      - id: destroyed-symlinks
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        args: [--allow-multiple-documents]
      - id: check-toml
      - id: check-ast
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-shebang-scripts-are-executable
      - id: detect-private-key
      - id: debug-statements
      - id: no-commit-to-branch
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.2
    hooks:
      - id: ruff
        args: [--select=F401, --fixable=F401]
        files: ^(benchmark/|docs/|examples/)
        exclude: \.ipynb$
  - repo: https://github.com/psf/black
    rev: 24.10.0
    hooks:
      - id: black-jupyter
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v18.1.8
    hooks:
      - id: clang-format
        types_or: [c++, cuda]
        args: [--style=file, --verbose]
  - repo: https://github.com/kynan/nbstripout
    rev: 0.8.1
    hooks:
      - id: nbstripout
        args:
          - '--keep-output'
          - '--extra-keys=metadata.kernelspec metadata.language_info.version'
@@ -0,0 +1,425 @@
## Profiling SGLang Infer System with AMD GPUs
This AppNote describes the SGLang profiling techniques, code changes, and running steps for systems with AMD Instinct GPUs; the same procedure may also work with NVIDIA GPUs.
Examples and steps are provided in detail to make the workflow easy to reproduce and to help localize performance problems for optimization.
Two primary methods are covered:
- [RPD](https://github.com/ROCm/rocmProfileData.git)
- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)

### Profiling SGLang Infer System with RPD Profiler
The RPD profiler is a low-overhead, cross-platform profiler, so the same RPD instrumentation works for profiling on ROCm/AMD GPUs as well as on CUDA/NVIDIA GPUs. To profile the SGLang repository with RPD, use the scripts and patch files included in this directory and follow the steps below:
1. Install RPD using install_rpd.sh, with rpd.patch applied during installation; both files are in this directory.

install_rpd.sh
```bash
# download and install RPD
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev

# install rpd module
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
cd rocmProfileData
git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
git apply rpd.patch
make && make install
cd rocpd_python && python setup.py install && cd ..
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
```

rpd.patch

```bash
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
index e9d9feb..b2e9e1a 100644
--- a/rpd_tracer/Makefile
+++ b/rpd_tracer/Makefile
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
        $(info Building with roctracer)
        RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
        RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
-       RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
+       RPD_SRCS += RoctracerDataSource.cpp
        RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
 endif
```
2. Copy the loadTracer.sh file included in this directory to /sglang/python/sglang.

loadTracer.sh
```bash
#!/bin/bash
################################################################################
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
################################################################################
OUTPUT_FILE="trace.rpd"

if [ "$1" = "-o" ] ; then
  OUTPUT_FILE=$2
  shift
  shift
fi

if [ -e ${OUTPUT_FILE} ] ; then
  rm ${OUTPUT_FILE}
fi

python3 -m rocpd.schema --create ${OUTPUT_FILE}
if [ $? != 0 ] ; then
  echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
  exit
fi

export RPDT_FILENAME=${OUTPUT_FILE}
export RPDT_AUTOSTART=0
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
```
3. Apply the patch (provided in this directory) with `git apply rpd_profile_server_enable.patch` if the main profiling goal is to capture GPU kernel information plus limited CPU activity.

#### Common Notes 1
Although we run TP=8 in the example, the patch intentionally logs RPD profiling on only two ranks (tp_rank 0 and 1) for profiling/visualization convenience, since even Perfetto's streaming mode can only load a JSON file of at most about 8 GB. With two ranks logged, we can still check for issues among ranks (e.g., load imbalance or NCCL problems) while logging a relatively long time window before the JSON file generated from the RPD file hits the 8 GB limit.

rpd_profile_server_enable.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..9021c01 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 logger = logging.getLogger(__name__)

@@ -245,6 +247,7 @@ class Scheduler:
             ],
             with_stack=True,
         )
+        self.rpd = rpdTracerControl()

     @torch.inference_mode()
     def event_loop(self):
@@ -1027,15 +1030,24 @@ class Scheduler:
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.start()
+        #self.profiler.start() #block pytorch profiler for rpd profiler enabling
+        if self.tp_rank == 0 or self.tp_rank == 1:
+            self.rpd.start()
+            self.rpd.rangePush("", "rpd profile range", "")
+            logger.info("rpd is enabled")

     def stop_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        #self.profiler.stop()
+        #self.profiler.export_chrome_trace(
+        #    self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        #)
+        if self.tp_rank ==0 or self.tp_rank ==1:
+            self.rpd.rangePop()
+            self.rpd.stop()
+            self.rpd.flush()
+            logger.info("rpd is done")
         logger.info("Profiler is done")
```
#### Advanced Debugging with RPD Profiler
Sometimes we want the RPD profiler to capture more CPU and Python activity in order to debug challenging issues (e.g., the root cause of load imbalance across GPU processes, or the root cause of bubbles). Only in such cases, apply the patch with `git apply rpd_profile_server_enable_wCPU_activities.patch`, which modifies three files.

rpd_profile_server_enable_wCPU_activities.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..2edb427 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 logger = logging.getLogger(__name__)

@@ -245,6 +247,7 @@ class Scheduler:
             ],
             with_stack=True,
         )
+        self.rpd = rpdTracerControl()

     @torch.inference_mode()
     def event_loop(self):
@@ -1027,15 +1030,26 @@ class Scheduler:
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.start()
+        #self.profiler.start()
+        logger.info("torch profiler is disabled")
+        if self.tp_rank == 0 or self.tp_rank == 1:
+            self.rpd.setPythonTrace(True)
+            self.rpd.start()
+            self.rpd.rangePush("", "scheduler", "")
+            logger.info("rpd is enabled inside scheduler profiling")

     def stop_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        #self.profiler.stop()
+        #self.profiler.export_chrome_trace(
+        #    self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        #)
+        if self.tp_rank ==0 or self.tp_rank ==1:
+            self.rpd.rangePop()
+            self.rpd.stop()
+            self.rpd.flush()
+            logger.info("rpd is done inside scheduler")
         logger.info("Profiler is done")


diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2621ccd..181df85 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import is_generation_model, is_multimodal_model

+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
+
+
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 logger = logging.getLogger(__name__)

@@ -514,10 +518,20 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)

     def start_profile(self):
+        rpd = rpdTracerControl()
+        rpd.setPythonTrace(True)
+        rpd.start()
+        rpd.rangePush("", "tokenizer_manager", "")
+        logger.info("tokenizer_manager rpd profiling started!")
         req = ProfileReq.START_PROFILE
         self.send_to_scheduler.send_pyobj(req)

     def stop_profile(self):
+        rpd = rpdTracerControl()
+        rpd.rangePop()
+        rpd.stop()
+        rpd.flush()
+        logger.info("rpd profiling is done inside tokenizer_manager!")
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7111c93..2bd722c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -30,6 +30,8 @@ import threading
 import time
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -152,6 +154,11 @@ async def flush_cache():
 @app.post("/start_profile")
 async def start_profile():
     """Start profiling."""
+    rpd = rpdTracerControl()
+    rpd.setPythonTrace(True)
+    rpd.start()
+    rpd.rangePush("", "server rpd profile range", "")
+    logger.info("rpd profiling started in server.py!")
     tokenizer_manager.start_profile()
     return Response(
         content="Start profiling.\n",

@@ -164,6 +171,11 @@ async def start_profile():
 async def stop_profile():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
+    rpd = rpdTracerControl()
+    rpd.rangePop()
+    rpd.stop()
+    rpd.flush()
+    logger.info("rpd profiling is done in server.py!")
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,
```
4. As an example of grok1 profiling, create a dummy_grok1 directory containing the config.json shown below, and copy the directory to the path given to `--model-path` if you want to use the example server.sh provided.
```bash
cat ../dummy_grok1/config.json
{
    "architectures": [
        "Grok1ModelForCausalLM"
    ],
    "embedding_multiplier_scale": 78.38367176906169,
    "output_multiplier_scale": 0.5773502691896257,
    "vocab_size": 131072,
    "hidden_size": 6144,
    "intermediate_size": 32768,
    "max_position_embeddings": 8192,
    "num_experts_per_tok": 2,
    "num_local_experts": 8,
    "num_attention_heads": 48,
    "num_hidden_layers": 64,
    "num_key_value_heads": 8,
    "head_dim": 128,
    "rms_norm_eps": 1e-05,
    "rope_theta": 10000.0,
    "model_type": "mixtral",
    "torch_dtype": "bfloat16"
}
```
5. Launch the server with the RPD-enabled script ./server.sh in one terminal inside the Docker container.

#### Common Notes 2
- Remember to change `--model-path` to the correct path.
- loadTracer.sh is needed to conduct RPD profiling.
- SGLANG_TORCH_PROFILER_DIR is used by the default torch profiler.
- Do not use loadTracer.sh if you are using the torch profiler; simply run `python3 -m sglang.launch_server`.

server.sh
```bash
#!/bin/bash

# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/

# Get the current timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")

# Define the log file with a timestamp
LOGFILE="sglang_server_log_$TIMESTAMP.json"

# Run the Python command and save the output to the log file
loadTracer.sh python3 -m sglang.launch_server \
    --model-path /sgl-workspace/sglang/dummy_grok1 \
    --tokenizer-path Xenova/grok-1-tokenizer \
    --load-format dummy \
    --quantization fp8 \
    --tp 8 \
    --port 30000 \
    --disable-radix-cache 2>&1 | tee "$LOGFILE"
```
6. Open another terminal attached to the same Docker container, and run the RPD-enabled ./client.sh after you see the "The server is fired up and is ready to roll!" message in the server-side terminal.

#### Common Notes 3
- Use `curl http://localhost:30000/start_profile` and `curl http://localhost:30000/stop_profile` to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
- Please don't use the RPD profiler together with the PyTorch profiler, to avoid interference.
- The rocmProfileData/tools/rpd2tracing.py script is used to generate a JSON file from the RPD file.

client.sh
```bash
#!/bin/bash

# Start profiling via API
curl http://localhost:30000/start_profile -H "Content-Type: application/json"

# Benchmark serving using sglang with random dataset and tokenizer
# Define the log file with a timestamp
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
LOGFILE="sglang_client_log_$TIMESTAMP.json"

# Run the benchmark with specified parameters and save logs
python3 -m sglang.bench_serving \
    --backend sglang \
    --tokenizer Xenova/grok-1-tokenizer \
    --dataset-name random \
    --random-input 1024 \
    --random-output 1024 \
    --num-prompts 120 \
    --request-rate 8 \
    --output-file online.jsonl 2>&1 | tee "$LOGFILE"

# Stop profiling via API
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"

# Convert tracing file to csv & json
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
```
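
Once client.sh has written trace.csv, a quick way to rank the most expensive GPU kernels is to aggregate that CSV. This is a hedged sketch: the column names (`Name`, `TotalCalls`, `TotalDuration`) are assumptions about the RPD `top` view and may differ across rocmProfileData versions, so adjust them to match the header row of your trace.csv.

```python
import csv
from collections import defaultdict

# Hedged sketch: summarize trace.csv produced by client.sh. The column names
# used here are assumptions; check the header row of your trace.csv and rename
# accordingly.
totals = defaultdict(lambda: [0, 0.0])
with open("trace.csv") as f:
    for row in csv.DictReader(f):
        name = row.get("Name", "unknown")
        totals[name][0] += int(float(row.get("TotalCalls", 1)))
        totals[name][1] += float(row.get("TotalDuration", 0.0))

# Print the ten entries with the largest accumulated duration.
for name, (calls, dur) in sorted(totals.items(), key=lambda kv: kv[1][1], reverse=True)[:10]:
    print(f"{dur:14.1f}  {calls:8d}  {name}")
```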
7. Follow the [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large JSON files. Try to adjust parameters so that the trace.json file size stays below 9 GB.

### Profiling SGLang Infer System with PyTorch Profiler

Please follow the steps below:

1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks to be recorded during profiling.

torch_profiler.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..6ecd78c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -240,7 +240,6 @@ class Scheduler:
         )
         self.profiler = torch.profiler.profile(
             activities=[
-                torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
@@ -1033,9 +1032,11 @@ class Scheduler:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        if self.tp_rank == 0:
+            with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
+                print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
+            print("Profiling stats done.")
+
         logger.info("Profiler is done")
```
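
For reference, the profiler configuration that the patch enables (CUDA-only activity, stack capture, and a `key_averages()` summary) can be exercised outside SGLang with a small standalone script. This is a minimal sketch for orientation, not part of the patch; the tiny linear layer and shapes are placeholders.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Minimal standalone version of the pattern used in torch_profiler.patch:
# record CUDA activity with stacks, then dump an aggregated kernel table.
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
x = torch.randn(32, 4096, device="cuda", dtype=torch.float16)

with profile(activities=[ProfilerActivity.CUDA], with_stack=True) as prof:
    for _ in range(10):
        model(x)
    torch.cuda.synchronize()

print(prof.key_averages(group_by_input_shape=True).table(
    sort_by="cuda_time_total", row_limit=20))
```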

2. Create the model path directory and copy it to the path given to `--model-path` if you want to use the server.sh file provided.

3. Modify the included server.sh by removing "loadTracer.sh" before the python command, then launch ./server.sh in one terminal inside the Docker container.

4. Similar to step 6 in the RPD profiling section, but remove the last two lines of client.sh, which convert the RPD file into CSV and JSON files. Run the modified client.sh for PyTorch profiling.
-------
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
@@ -0,0 +1,27 @@
#!/bin/bash

# Start profiling via API
curl http://localhost:30000/start_profile -H "Content-Type: application/json"

# Benchmark serving using sglang with random dataset and tokenizer
# Define the log file with a timestamp
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
LOGFILE="sglang_client_log_$TIMESTAMP.json"

# Run the benchmark with specified parameters and save logs
python3 -m sglang.bench_serving \
    --backend sglang \
    --tokenizer Xenova/grok-1-tokenizer \
    --dataset-name random \
    --random-input 1024 \
    --random-output 1024 \
    --num-prompts 240 \
    --request-rate 8 \
    --output-file online.jsonl 2>&1 | tee "$LOGFILE"

# Stop profiling via API
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"

# Convert tracing file to csv & json
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
@@ -0,0 +1,10 @@
# download and install RPD
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev

# install rpd module
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
cd rocmProfileData
git apply rpd.patch
make && make install
cd rocpd_python && python setup.py install && cd ..
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
@@ -0,0 +1,43 @@
#!/bin/bash
################################################################################
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
################################################################################
OUTPUT_FILE="trace.rpd"

if [ "$1" = "-o" ] ; then
  OUTPUT_FILE=$2
  shift
  shift
fi

if [ -e ${OUTPUT_FILE} ] ; then
  rm ${OUTPUT_FILE}
fi

python3 -m rocpd.schema --create ${OUTPUT_FILE}
if [ $? != 0 ] ; then
  echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
  exit
fi

export RPDT_FILENAME=${OUTPUT_FILE}
export RPDT_AUTOSTART=0
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
@@ -0,0 +1,12 @@
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
index e9d9feb..b2e9e1a 100644
--- a/rpd_tracer/Makefile
+++ b/rpd_tracer/Makefile
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
        $(info Building with roctracer)
        RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
        RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
-       RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
+       RPD_SRCS += RoctracerDataSource.cpp
        RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
 endif
@@ -0,0 +1,49 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..9021c01 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 logger = logging.getLogger(__name__)

@@ -245,6 +247,7 @@ class Scheduler:
             ],
             with_stack=True,
         )
+        self.rpd = rpdTracerControl()

     @torch.inference_mode()
     def event_loop(self):
@@ -1027,15 +1030,24 @@ class Scheduler:
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.start()
+        #self.profiler.start() #block pytorch profiler for rpd profiler enabling
+        if self.tp_rank == 0 or self.tp_rank == 1:
+            self.rpd.start()
+            self.rpd.rangePush("", "rpd profile range", "")
+            logger.info("rpd is enabled")

     def stop_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        #self.profiler.stop()
+        #self.profiler.export_chrome_trace(
+        #    self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        #)
+        if self.tp_rank ==0 or self.tp_rank ==1:
+            self.rpd.rangePop()
+            self.rpd.stop()
+            self.rpd.flush()
+            logger.info("rpd is done")
         logger.info("Profiler is done")
sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch (vendored, 126 lines)
@@ -0,0 +1,126 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..2edb427 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 logger = logging.getLogger(__name__)

@@ -245,6 +247,7 @@ class Scheduler:
             ],
             with_stack=True,
         )
+        self.rpd = rpdTracerControl()

     @torch.inference_mode()
     def event_loop(self):
@@ -1027,15 +1030,26 @@ class Scheduler:
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.start()
+        #self.profiler.start()
+        logger.info("torch profiler is disabled")
+        if self.tp_rank == 0 or self.tp_rank == 1:
+            self.rpd.setPythonTrace(True)
+            self.rpd.start()
+            self.rpd.rangePush("", "scheduler", "")
+            logger.info("rpd is enabled inside scheduler profiling")

     def stop_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        #self.profiler.stop()
+        #self.profiler.export_chrome_trace(
+        #    self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        #)
+        if self.tp_rank ==0 or self.tp_rank ==1:
+            self.rpd.rangePop()
+            self.rpd.stop()
+            self.rpd.flush()
+            logger.info("rpd is done inside scheduler")
         logger.info("Profiler is done")


diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2621ccd..181df85 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import is_generation_model, is_multimodal_model

+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
+
+
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 logger = logging.getLogger(__name__)

@@ -514,10 +518,20 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)

     def start_profile(self):
+        rpd = rpdTracerControl()
+        rpd.setPythonTrace(True)
+        rpd.start()
+        rpd.rangePush("", "tokenizer_manager", "")
+        logger.info("tokenizer_manager rpd profiling started!")
         req = ProfileReq.START_PROFILE
         self.send_to_scheduler.send_pyobj(req)

     def stop_profile(self):
+        rpd = rpdTracerControl()
+        rpd.rangePop()
+        rpd.stop()
+        rpd.flush()
+        logger.info("rpd profiling is done inside tokenizer_manager!")
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7111c93..2bd722c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -30,6 +30,8 @@ import threading
 import time
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -152,6 +154,11 @@ async def flush_cache():
 @app.post("/start_profile")
 async def start_profile():
     """Start profiling."""
+    rpd = rpdTracerControl()
+    rpd.setPythonTrace(True)
+    rpd.start()
+    rpd.rangePush("", "server rpd profile range", "")
+    logger.info("rpd profiling started in server.py!")
     tokenizer_manager.start_profile()
     return Response(
         content="Start profiling.\n",

@@ -164,6 +171,11 @@ async def start_profile():
 async def stop_profile():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
+    rpd = rpdTracerControl()
+    rpd.rangePop()
+    rpd.stop()
+    rpd.flush()
+    logger.info("rpd profiling is done in server.py!")
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,
@@ -0,0 +1,20 @@
#!/bin/bash

# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/

# Get the current timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")

# Define the log file with a timestamp
LOGFILE="sglang_server_log_$TIMESTAMP.json"

# Run the Python command and save the output to the log file
loadTracer.sh python3 -m sglang.launch_server \
    --model-path /sgl-workspace/sglang/dummy_grok1 \
    --tokenizer-path Xenova/grok-1-tokenizer \
    --load-format dummy \
    --quantization fp8 \
    --tp 8 \
    --port 30000 \
    --disable-radix-cache 2>&1 | tee "$LOGFILE"
@@ -0,0 +1,25 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..6ecd78c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -240,7 +240,6 @@ class Scheduler:
         )
         self.profiler = torch.profiler.profile(
             activities=[
-                torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
@@ -1033,9 +1032,11 @@ class Scheduler:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        if self.tp_rank == 0:
+            with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
+                print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
+            print("Profiling stats done.")
+
         logger.info("Profiler is done")
@@ -0,0 +1,118 @@
## Tuning SGLang Infer System with AMD GPUs
This AppNote describes the SGLang performance tuning techniques, code harness, and running steps for systems with AMD Instinct GPUs.
Harness code, examples, and steps are provided in detail to make the workflow easy to reproduce and to help tune performance for your workloads.
Four primary runtime areas are covered:

## 1. Triton Kernels
To maximize Triton kernel efficiency, several strategies can be employed:

### Key Tuning Parameters:
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
```python
@triton.autotune(configs=[
    triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
    triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
    triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
    triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
    triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
    triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
    triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
    triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
    triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
@triton.jit
def _triton_kernel_function():
    ...
```
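
The tile sizes and MFMA shape are exposed through the same `triton.Config` mechanism. The sketch below is illustrative only: the key names mirror those used by the fused MOE tuning script later in this commit (`BLOCK_SIZE_M/N/K`, `GROUP_SIZE_M`, `matrix_instr_nonkdim`, `kpack`), and the concrete values are placeholder assumptions, not tuned recommendations. `OPTIMIZE_EPILOGUE=1` is set in the shell environment before launching the workload, like the other environment variables above.

```python
import triton

# Illustrative only: a single GEMM-style config that also pins the tile sizes
# and the MFMA shape. The values are placeholders to be swept per kernel/GPU.
gemm_config = triton.Config(
    {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "waves_per_eu": 2,
        "matrix_instr_nonkdim": 16,
        "kpack": 2,
    },
    num_warps=8,
    num_stages=2,
)
```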
## 2. Torch Tunable Operations
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.

### Key Environment Variables:
1. **PYTORCH_TUNABLEOP_ENABLED**:
   - Default: `0`
   - Set to `1` to enable TunableOp.

2. **PYTORCH_TUNABLEOP_TUNING**:
   - Default: `1`
   - Set to `0` to disable tuning. While tuning is enabled (and PYTORCH_TUNABLEOP_ENABLED is on), any operation without a tuned entry triggers the tuning step and the result is recorded.

3. **PYTORCH_TUNABLEOP_VERBOSE**:
   - Default: `0`
   - Set to `1` to enable verbose output for TunableOp.

### Usage Example:
To enable TunableOp and tuning, and optionally enable verbose mode, run the following commands in your terminal:

```bash
#Tuning
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh

#Inference with tuned ops
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh

#Print out the log
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
```
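
Newer PyTorch builds also expose these switches programmatically through `torch.cuda.tunable`. Whether that module is available depends on your PyTorch version, so treat the following as a hedged sketch rather than a guaranteed API:

```python
import torch

# Programmatic TunableOp control; availability depends on the PyTorch build,
# hence the hasattr guard. The calls mirror the environment variables above.
if hasattr(torch.cuda, "tunable"):
    torch.cuda.tunable.enable(True)          # like PYTORCH_TUNABLEOP_ENABLED=1
    torch.cuda.tunable.tuning_enable(True)   # like PYTORCH_TUNABLEOP_TUNING=1
    print("TunableOp enabled:", torch.cuda.tunable.is_enabled())
```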
## 3. Torch Compilation

The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.

To tune Triton kernels for GEMM and convolution ops, use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.

### Key Configurations:
1. **Max Autotune**:
   - Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.

2. **Fine-Grained Control**:
   - Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
   - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`.

3. **Backend Selection**:
   - Use `torch._inductor.config.max_autotune_gemm_backends` to limit backends to TRITON for better performance.

4. **Freezing for Inference**:
   - Use `torch._inductor.config.freezing = True` to enable constant folding optimizations.

5. **Debugging**:
   - Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.

### Example Code Block:
```bash
#Gemm Tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh

#Specify your backend to TRITON for Gemm Tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh

#Inference with large improvement on AMD GPU
TORCHINDUCTOR_FREEZING=1 your_script.sh
```
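
The same knobs can be set from Python together with `torch.compile(mode="max-autotune")`. The snippet below is a minimal sketch; the tiny linear layer and shapes are placeholders for illustration only:

```python
import torch
import torch._inductor.config as inductor_config

# Mirror the environment-variable settings programmatically.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "TRITON"
inductor_config.freezing = True  # constant folding for inference

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

# mode="max-autotune" benchmarks candidate Triton configs per shape.
compiled = torch.compile(model, mode="max-autotune")

with torch.inference_mode():
    out = compiled(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
```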
## 4. Fused MOE kernel
To maximize MOE kernel efficiency, use the script below to find the best launch configuration.

### Key parameters:
- **--model**: the MOE model type to tune; it automatically determines d_model, model_intermediate_size, and num_layers.
- **--tp-size**: simulates the whole-model run configuration so that the tensor-parallel dimension sizes are set correctly.
- **--batch**: the M dimension of the MOE kernel; for the prefill MOE kernel the value is batch*input_len, and for the decode MOE kernel the value is batch.
- **--dtype**: computation data type.

```bash
#Tuning
#For example, take a run such as "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8".
#It defines batch-size 32, input length 1024, and output length 8. From the MOE "--batch" point of view, the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 (only one output token is generated per step).
#So we can tune the decode MOE with the command below
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
# and use this command to tune the prefill MOE
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
```
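
The sweep can also be driven programmatically through the `main()` entry point of benchmark_moe_rocm.py (shown later in this commit). This is a hedged sketch: it assumes the script is importable from the working directory and that the model argument resolves to a directory containing a HuggingFace-style config.json, as in the grok1 example above.

```python
# Hedged sketch: drive the tuning sweep via main() from benchmark_moe_rocm.py
# instead of the CLI. The paths and values mirror the grok1 example above and
# are assumptions, not tested settings.
from benchmark_moe_rocm import main

main(
    model="dummy_grok1/",     # directory containing config.json
    tp_size=8,
    dtype="float8",
    batches=["32", "32768"],  # decode-shaped and prefill-shaped M dimensions
)
```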

## Reference

For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:

[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
@ -0,0 +1,380 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe,
|
||||||
|
get_config_file_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
|
||||||
|
|
||||||
|
def main(model, tp_size, dtype: str, batches):
|
||||||
|
method = fused_moe
|
||||||
|
|
||||||
|
for bs in batches:
|
||||||
|
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def prune_configs(M, N, K, configs):
|
||||||
|
pruned_configs = []
|
||||||
|
elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
||||||
|
elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
||||||
|
|
||||||
|
mfma = 16 if M < 32 or N < 32 else 32
|
||||||
|
|
||||||
|
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||||
|
large_gemm = False
|
||||||
|
if M >= 2048 and N >= 2048:
|
||||||
|
large_gemm = True
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||||
|
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||||
|
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||||
|
num_warps = config.get("num_warps")
|
||||||
|
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||||
|
# kpack = config.get("kpack")
|
||||||
|
if matrix_instr_nonkdim > mfma:
|
||||||
|
continue
|
||||||
|
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||||
|
continue
|
||||||
|
# some layouts could not work properly in case
|
||||||
|
# number elements per thread is less 1
|
||||||
|
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||||
|
continue
|
||||||
|
SPLIT_K = 1 # config.get("SPLIT_K")
|
||||||
|
GROUP_M = config.get("GROUP_SIZE_M")
|
||||||
|
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
|
||||||
|
continue
|
||||||
|
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
||||||
|
continue
|
||||||
|
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
||||||
|
continue
|
||||||
|
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||||
|
# unless BLOCK_SIZE is already small enough
|
||||||
|
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||||
|
continue
|
||||||
|
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||||
|
continue
|
||||||
|
# skip large split_k when not necessary
|
||||||
|
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||||
|
continue
|
||||||
|
# skip split_k that leads to EVEN_K = false
|
||||||
|
leap = SPLIT_K * BLOCK_SIZE_K
|
||||||
|
modv = K % leap
|
||||||
|
if modv != 0:
|
||||||
|
continue
|
||||||
|
# skip large GROUP_M
|
||||||
|
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||||
|
continue
|
||||||
|
# out of shared memory resource
|
||||||
|
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||||
|
LDS = (
|
||||||
|
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
||||||
|
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||||
|
)
|
||||||
|
if LDS > 65536:
|
||||||
|
continue
|
||||||
|
# Skip small block sizes and num_warps for large gemm
|
||||||
|
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||||
|
if large_gemm:
|
||||||
|
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||||
|
continue
|
||||||
|
if BLOCK_SIZE_K < 64:
|
||||||
|
continue
|
||||||
|
if num_warps < 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pruned_configs.append(config)
|
||||||
|
|
||||||
|
return pruned_configs
|
||||||
|
|
||||||
|
|
||||||
|
def union_of_list_of_dicts(l1, l2):
|
||||||
|
result = []
|
||||||
|
temp_list = l1.copy()
|
||||||
|
temp_list.extend(l2)
|
||||||
|
for myDict in temp_list:
|
||||||
|
if myDict not in result:
|
||||||
|
result.append(myDict)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_grid(bs, model, method, tp_size, dtype: str):
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(model)
|
||||||
|
|
||||||
|
top_k = config.num_experts_per_tok
|
||||||
|
d_model = config.hidden_size
|
||||||
|
model_intermediate_size = config.intermediate_size
|
||||||
|
num_layers = config.num_hidden_layers
|
||||||
|
hidden_states_dtype = config.torch_dtype
|
||||||
|
|
||||||
|
if config.num_experts_per_tok:
|
||||||
|
if config.architectures[0] == "Grok1ModelForCausalLM":
|
||||||
|
num_total_experts = config.num_experts
|
||||||
|
else:
|
||||||
|
num_total_experts = config.num_local_experts
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported Mixtral model {model}")
|
||||||
|
|
||||||
|
# tp_size = 2
|
||||||
|
num_warmup_calls = 10
|
||||||
|
num_calls = 30
|
||||||
|
|
||||||
|
num_warmup_trials = 1
|
||||||
|
num_trials = 1
|
||||||
|
|
||||||
|
full_configs = []
|
||||||
|
|
||||||
|
block_m_range = [16, 32, 64, 128, 256]
|
||||||
|
block_n_range = [16, 32, 64, 128, 256]
|
||||||
|
block_k_range = [32, 64, 128, 256] # MUST >= 32
|
||||||
|
num_warps_range = [1, 2, 4, 8]
|
||||||
|
group_m_range = [1, 4, 8, 16, 32]
|
||||||
|
# For now we only use num_stages=2 for all the gemm configs we care about
|
||||||
|
# But keep this explicit so that we do not forget we may need to set it to
|
||||||
|
# other values in the future
|
||||||
|
num_stage_range = [2]
|
||||||
|
waves_per_eu_range = [0, 1, 2, 4, 8]
|
||||||
|
# Remove 32 because of triton compiling error
|
||||||
|
matrix_instr_nonkdim_range = [16]
|
||||||
|
kpack_range = [1, 2]
|
||||||
|
|
||||||
|
for block_size_m in block_m_range:
|
||||||
|
for block_size_n in block_n_range:
|
||||||
|
for block_size_k in block_k_range:
|
||||||
|
for group_size_m in group_m_range:
|
||||||
|
for num_warps in num_warps_range:
|
||||||
|
for num_stages in num_stage_range:
|
||||||
|
for waves_per_eu in waves_per_eu_range:
|
||||||
|
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
|
||||||
|
for kpack in kpack_range:
|
||||||
|
full_configs.append(
|
||||||
|
{
|
||||||
|
"BLOCK_SIZE_M": block_size_m,
|
||||||
|
"BLOCK_SIZE_N": block_size_n,
|
||||||
|
"BLOCK_SIZE_K": block_size_k,
|
||||||
|
"GROUP_SIZE_M": group_size_m,
|
||||||
|
"num_warps": num_warps,
|
||||||
|
"num_stages": num_stages,
|
||||||
|
"waves_per_eu": waves_per_eu,
|
||||||
|
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||||
|
"kpack": kpack,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
M1 = bs * 2
|
||||||
|
N1 = model_intermediate_size * 2 // tp_size
|
||||||
|
K1 = d_model
|
||||||
|
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
|
||||||
|
|
||||||
|
M2 = bs * 2
|
||||||
|
N2 = d_model
|
||||||
|
K2 = model_intermediate_size // tp_size
|
||||||
|
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
|
||||||
|
|
||||||
|
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
|
||||||
|
{len(prune_configs_2)=} | {len(configs)=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
best_config = None
|
||||||
|
best_time_us = 1e20
|
||||||
|
|
||||||
|
print(f"{tp_size=} {bs=}")
|
||||||
|
|
||||||
|
for config in tqdm(configs):
|
||||||
|
# warmup
|
||||||
|
try:
|
||||||
|
print(config)
|
||||||
|
for _ in range(num_warmup_trials):
|
||||||
|
run_timing(
|
||||||
|
num_calls=num_warmup_calls,
|
||||||
|
bs=bs,
|
||||||
|
d_model=d_model,
|
||||||
|
num_total_experts=num_total_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
tp_size=tp_size,
|
||||||
|
model_intermediate_size=model_intermediate_size,
|
||||||
|
method=method,
|
||||||
|
config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
hidden_states_dtype=hidden_states_dtype,
|
||||||
|
)
|
||||||
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# trial
|
||||||
|
for _ in range(num_trials):
|
||||||
|
kernel_dur_ms = run_timing(
|
||||||
|
num_calls=num_calls,
|
||||||
|
bs=bs,
|
||||||
|
d_model=d_model,
|
||||||
|
num_total_experts=num_total_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
tp_size=tp_size,
|
||||||
|
model_intermediate_size=model_intermediate_size,
|
||||||
|
method=method,
|
||||||
|
config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
hidden_states_dtype=hidden_states_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel_dur_us = 1000 * kernel_dur_ms
|
||||||
|
model_dur_ms = kernel_dur_ms * num_layers
|
||||||
|
|
||||||
|
if kernel_dur_us < best_time_us:
|
||||||
|
best_config = config
|
||||||
|
best_time_us = kernel_dur_us
|
||||||
|
|
||||||
|
tqdm.write(
|
||||||
|
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
|
||||||
|
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
|
||||||
|
f"{d_model=} {model_intermediate_size=} {num_layers=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("best_time_us", best_time_us)
|
||||||
|
print("best_config", best_config)
|
||||||
|
|
||||||
|
# holds Dict[str, Dict[str, int]]
|
||||||
|
filename = get_config_file_name(
|
||||||
|
num_total_experts,
|
||||||
|
model_intermediate_size // tp_size,
|
||||||
|
"float8" if dtype == "float8" else None,
|
||||||
|
)
|
||||||
|
print(f"writing config to file {filename}")
|
||||||
|
existing_content = {}
|
||||||
|
if os.path.exists(filename):
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
existing_content = json.load(f)
|
||||||
|
existing_content[str(bs)] = best_config
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(existing_content, f, indent=4)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_timing(
|
||||||
|
num_calls: int,
|
||||||
|
bs: int,
|
||||||
|
d_model: int,
|
||||||
|
num_total_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
tp_size: int,
|
||||||
|
model_intermediate_size: int,
|
||||||
|
method,
|
||||||
|
config,
|
||||||
|
dtype: str,
|
||||||
|
hidden_states_dtype,
|
||||||
|
) -> float:
|
||||||
|
shard_intermediate_size = model_intermediate_size // tp_size
|
||||||
|
|
||||||
|
hidden_states = torch.rand(
|
||||||
|
(bs, d_model),
|
||||||
|
device="cuda:0",
|
||||||
|
dtype=hidden_states_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
w1 = torch.rand(
|
||||||
|
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
w2 = torch.rand(
|
||||||
|
(num_total_experts, d_model, shard_intermediate_size + padding_size),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
w1_scale = None
|
||||||
|
w2_scale = None
|
||||||
|
a1_scale = None
|
||||||
|
a2_scale = None
|
||||||
|
|
||||||
|
if dtype == "float8":
|
||||||
|
w1 = w1.to(torch.float8_e4m3fnuz)
|
||||||
|
w2 = w2.to(torch.float8_e4m3fnuz)
|
||||||
|
w1_scale = torch.ones(
|
||||||
|
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
w2_scale = torch.ones(
|
||||||
|
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
||||||
|
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
gating_output = F.softmax(
|
||||||
|
torch.rand(
|
||||||
|
(num_calls, bs, num_total_experts),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
##################################
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
for i in range(num_calls):
|
||||||
|
hidden_states = method(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
gating_output=gating_output[0],
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=True,
|
||||||
|
inplace=True,
|
||||||
|
override_config=config,
|
||||||
|
use_fp8=dtype == "float8",
|
||||||
|
)
|
||||||
|
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
|
||||||
|
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
||||||
|
return dur_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="benchmark_mixtral_moe",
|
||||||
|
description="Benchmark and tune the fused_moe kernel",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=["float8", "float16", "bfloat16"],
|
||||||
|
help="Data type used for fused_moe kernel computations",
|
||||||
|
)
|
||||||
|
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
|
||||||
|
|
||||||
|
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
|
||||||
|
parser.add_argument("-b", "--batches", type=str)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
batches = args.batches.split(",")
|
||||||
|
|
||||||
|
sys.exit(main(args.model, args.tp_size, args.dtype, batches))
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
.PHONY: check-deps install-deps format update help
|
||||||
|
|
||||||
|
# Show help for each target
|
||||||
|
help:
|
||||||
|
@echo "Available targets:"
|
||||||
|
@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
|
check-deps: ## Check and install required Python formatting dependencies
|
||||||
|
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
|
||||||
|
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
|
||||||
|
|
||||||
|
install-deps: ## Install Python formatting tools (isort and black)
|
||||||
|
pip install isort black
|
||||||
|
|
||||||
|
format: check-deps ## Format modified Python files using isort and black
|
||||||
|
@echo "Formatting modified Python files..."
|
||||||
|
git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
|
||||||
|
|
||||||
|
FILES_TO_UPDATE = docker/Dockerfile.rocm \
|
||||||
|
python/pyproject.toml \
|
||||||
|
python/sglang/version.py \
|
||||||
|
docs/developer/setup_github_runner.md \
|
||||||
|
docs/start/install.md
|
||||||
|
|
||||||
|
update: ## Update version numbers across project files. Usage: make update <new_version>
|
||||||
|
@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
|
||||||
|
echo "Version required. Usage: make update <new_version>"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
@OLD_VERSION=$$(grep "version" python/sglang/version.py | cut -d '"' -f2); \
|
||||||
|
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
|
||||||
|
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
|
||||||
|
for file in $(FILES_TO_UPDATE); do \
|
||||||
|
if [ "$(shell uname)" = "Darwin" ]; then \
|
||||||
|
sed -i '' -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
|
||||||
|
else \
|
||||||
|
sed -i -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
|
||||||
|
fi \
|
||||||
|
done; \
|
||||||
|
echo "Version update complete"
|
||||||
|
|
||||||
|
%:
|
||||||
|
@:
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
<div align="center" id="sglangtop">
|
||||||
|
<img src="https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" alt="logo" width="400" margin="10px"></img>
|
||||||
|
|
||||||
|
[](https://pypi.org/project/sglang)
|
||||||
|

|
||||||
|
[](https://github.com/sgl-project/sglang/tree/main/LICENSE)
|
||||||
|
[](https://github.com/sgl-project/sglang/issues)
|
||||||
|
[](https://github.com/sgl-project/sglang/issues)
|
||||||
|
[-006BFF)](https://gurubase.io/g/sglang)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/)
|
||||||
|
| [**Documentation**](https://docs.sglang.ai/)
|
||||||
|
| [**Join Slack**](https://slack.sglang.ai/)
|
||||||
|
| [**Join Bi-Weekly Development Meeting**](https://meeting.sglang.ai/)
|
||||||
|
| [**Roadmap**](https://github.com/sgl-project/sglang/issues/4042)
|
||||||
|
| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
|
||||||
|
|
||||||
|
## News
|
||||||
|
- [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))
|
||||||
|
- [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
|
||||||
|
- [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))
|
||||||
|
- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
|
||||||
|
- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
|
||||||
|
- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
|
||||||
|
- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>More</summary>
|
||||||
|
|
||||||
|
- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
|
||||||
|
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
|
||||||
|
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
|
||||||
|
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## About
|
||||||
|
SGLang is a fast serving framework for large language models and vision language models.
|
||||||
|
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
|
||||||
|
The core features include:
|
||||||
|
|
||||||
|
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, zero-overhead CPU scheduler, continuous batching, token attention (paged attention), speculative decoding, tensor parallelism, chunked prefill, structured outputs, and quantization (FP8/INT4/AWQ/GPTQ).
|
||||||
|
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
|
||||||
|
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
|
||||||
|
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
- [Install SGLang](https://docs.sglang.ai/start/install.html)
|
||||||
|
- [Quick Start](https://docs.sglang.ai/backend/send_request.html)
|
||||||
|
- [Backend Tutorial](https://docs.sglang.ai/backend/openai_api_completions.html)
|
||||||
|
- [Frontend Tutorial](https://docs.sglang.ai/frontend/frontend.html)
|
||||||
|
- [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html)
|
||||||
|
|
||||||
|
## Benchmark and Performance
|
||||||
|
Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
[Development Roadmap (2025 H1)](https://github.com/sgl-project/sglang/issues/4042)
|
||||||
|
|
||||||
|
## Adoption and Sponsorship
|
||||||
|
The project has been deployed to large-scale production, generating trillions of tokens every day.
|
||||||
|
It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.
|
||||||
|
|
||||||
|
<img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>
|
||||||
|
|
||||||
|
## Contact Us
|
||||||
|
|
||||||
|
For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai.
|
||||||
|
|
||||||
|
## Acknowledgment and Citation
|
||||||
|
We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 393 KiB |
|
|
@ -0,0 +1,130 @@
|
||||||
|
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
|
||||||
|
#
|
||||||
|
# Launch a server:
|
||||||
|
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning
|
||||||
|
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang import set_default_backend
|
||||||
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_string(token_length: int) -> str:
|
||||||
|
random_string = "".join(
|
||||||
|
random.choices(string.ascii_letters + string.digits, k=token_length * 100)
|
||||||
|
)
|
||||||
|
tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
|
||||||
|
:token_length
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(tokenized_output) < token_length:
|
||||||
|
tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
|
||||||
|
token_length - len(tokenized_output)
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
|
||||||
|
return decoded_string
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unique_prefix(base_text, index):
|
||||||
|
return str(index) + base_text[len(str(index)) :]
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def text_qa(s, question, gen_len):
|
||||||
|
s += "Q: " + question + "\n"
|
||||||
|
s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
|
||||||
|
base_prefix = generate_random_string(prefix_length)
|
||||||
|
|
||||||
|
tot_input_len = 0
|
||||||
|
all_prompts = []
|
||||||
|
for i in tqdm(range(num_prefix), desc="prepare prompts"):
|
||||||
|
unique_prefix = generate_unique_prefix(base_prefix, i)
|
||||||
|
prompt_list = []
|
||||||
|
for j in range(num_samples_per_prefix):
|
||||||
|
suffix = generate_random_string(suffix_length)
|
||||||
|
prompt = unique_prefix + suffix
|
||||||
|
prompt_list.append(prompt)
|
||||||
|
tot_input_len += len(tokenizer.encode(prompt))
|
||||||
|
all_prompts.append(prompt_list)
|
||||||
|
return all_prompts, tot_input_len
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_by_batch(all_prompts, gen_len):
|
||||||
|
backend.flush_cache()
|
||||||
|
|
||||||
|
tot_time = 0
|
||||||
|
for i in range(len(all_prompts)):
|
||||||
|
tic = time.time()
|
||||||
|
text_qa.run_batch(
|
||||||
|
list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
|
||||||
|
)
|
||||||
|
tot_time += time.time() - tic
|
||||||
|
|
||||||
|
return tot_time
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_by_batch_with_hint(all_prompts, gen_len):
|
||||||
|
backend.flush_cache()
|
||||||
|
|
||||||
|
tot_time = 0
|
||||||
|
for i in range(len(all_prompts)):
|
||||||
|
tic = time.time()
|
||||||
|
# Send a hint to cache the prefix
|
||||||
|
text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
|
||||||
|
# Send the batch
|
||||||
|
text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))
|
||||||
|
|
||||||
|
tot_time += time.time() - tic
|
||||||
|
|
||||||
|
return tot_time
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_all(all_prompts, gen_len):
|
||||||
|
backend.flush_cache()
|
||||||
|
|
||||||
|
all_prompts = [x for prompt_list in all_prompts for x in prompt_list]
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
text_qa.run_batch(
|
||||||
|
list(zip(all_prompts, [gen_len] * len(all_prompts))),
|
||||||
|
)
|
||||||
|
tot_time = time.time() - tic
|
||||||
|
|
||||||
|
return tot_time
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
|
backend = RuntimeEndpoint("http://127.0.0.1:30000")
|
||||||
|
set_default_backend(backend)
|
||||||
|
|
||||||
|
random.seed(0)
|
||||||
|
num_prefix = 10
|
||||||
|
num_samples_per_prefix = 32
|
||||||
|
prefix_length = 1024
|
||||||
|
suffix_length = 128
|
||||||
|
gen_len = 1
|
||||||
|
all_prompts, tot_input_len = prepare_prompts(
|
||||||
|
num_prefix, num_samples_per_prefix, prefix_length, suffix_length
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Total input token length: {tot_input_len}\n")
|
||||||
|
|
||||||
|
cost = test_batch_by_batch(all_prompts, gen_len)
|
||||||
|
print(f"Latency of test_batch_by_batch : {cost:.4f} s\n")
|
||||||
|
|
||||||
|
cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
|
||||||
|
print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")
|
||||||
|
|
||||||
|
cost = test_send_all(all_prompts, gen_len)
|
||||||
|
print(f"Latency of test_send_all : {cost:.4f} s\n")
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0
|
||||||
|
|
||||||
|
In short, with multi-step enabled, in the online scenarios we benchmarked, vLLM's Median TTFT is **3 times** that of SGLang, and its Median ITL is **10 times** that of SGLang (lower Median TTFT and ITL are better). In other words, vLLM's multi-step optimization does not improve throughput while also keeping Median TTFT and ITL low. Also, in the maximum-throughput benchmark, if vLLM uses its default configuration instead of separately setting GPU utilization to 0.95, its maximum throughput is **lower** than SGLang's.
|
||||||
|
|
||||||
|
## Online benchmark results
|
||||||
|
|
||||||
|
### Llama 3.1 8B Instruct 1 x A100 80G
|
||||||
|
|
||||||
|
| RPS | Num prompts | Engine | Median E2E Latency (ms) | Median TTFT (ms) | Median TPOT (ms) | Median ITL (ms) |
|
||||||
|
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
||||||
|
| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** |
|
||||||
|
| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** |
|
||||||
|
| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** |
|
||||||
|
| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** |
|
||||||
|
|
||||||
|
### Llama 3.1 70B Instruct 4 x H100 80G
|
||||||
|
|
||||||
|
| RPS | Num Prompts | Engine | Median E2E Latency (ms) | Median TTFT (ms) | Median TPOT (ms) | Median ITL (ms) |
|
||||||
|
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
||||||
|
| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** |
|
||||||
|
| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** |
|
||||||
|
| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** |
|
||||||
|
| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** |
|
||||||
|
|
||||||
|
## Offline benchmark results
|
||||||
|
|
||||||
|
### Llama 3.1 8B Instruct 1 x A100 80G
|
||||||
|
|
||||||
|
| RPS | Num Prompts | Engine | Request throughput (req/s) | Output token throughput (tok/s) |
|
||||||
|
|------|-------------|--------|--------------------|-------------------------|
|
||||||
|
| inf | 5000 | SGLang | 22.03 | **4281.51** |
|
||||||
|
| inf | 5000 | vLLM | 21.27 | **4132.37** |
|
||||||
|
|
||||||
|
### Llama 3.1 70B Instruct 4 x H100 80G
|
||||||
|
|
||||||
|
| RPS | Num Prompts | Engine | Request throughput (req/s) | Output token throughput (tok/s) |
|
||||||
|
|------|-------------|--------|--------------------|-------------------------|
|
||||||
|
| inf | 5000 | SGLang | 19.84 | **3856.01** |
|
||||||
|
| inf | 5000 | vLLM | 19.04 | **3700.64** |
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# install sglang v0.3.0
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install "sglang[all]"==0.3.0
|
||||||
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||||
|
|
||||||
|
# install vllm v0.6.0
|
||||||
|
pip install vllm==0.6.0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176 and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. vLLM's `gpu_memory_utilization` defaults to 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` defaults to 0.88 at TP 1 and 0.85 at TP 4, so we manually set SGLang's `mem_frac` to 0.88 at TP 4.
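As a concrete illustration of the notes above, the sketch below pins these memory settings explicitly for the TP 4 maximum-throughput runs (the commands mirror the ones later in this document; only `--gpu-memory-utilization 0.95` is added, for the vLLM case mentioned in the summary):

```bash
# SGLang at TP 4 with mem_frac pinned to 0.88
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88

# vLLM at TP 4 with gpu_memory_utilization raised from the default 0.9 to 0.95
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096 --gpu-memory-utilization 0.95
```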
|
||||||
|
|
||||||
|
## Online benchmarks
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Llama 3.1 8B Instruct on 1 x A100
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
||||||
|
|
||||||
|
# Llama 3.1 70B Instruct on 4 x H100
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
||||||
|
|
||||||
|
# bench serving
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
||||||
|
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
||||||
|
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
||||||
|
```
|
||||||
|
|
||||||
|
## Offline benchmarks
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Llama 3.1 8B Instruct on 1 x A100
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
||||||
|
|
||||||
|
# Llama 3.1 70B Instruct on 4 x H100
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
||||||
|
|
||||||
|
# bench serving
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000
|
||||||
|
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Create dummy weights:
|
||||||
|
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and place `config.json` and the tokenizer files under it.
|
||||||
|
# 2. Get `config.json` from ./config.md
|
||||||
|
# 3. Download the tokenizer
|
||||||
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
||||||
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||||
|
|
||||||
|
# Launch sglang
|
||||||
|
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
|
||||||
|
|
||||||
|
# offline
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21
|
||||||
|
|
||||||
|
# online
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Launch trtllm
|
||||||
|
# https://github.com/sgl-project/tensorrt-demo
|
||||||
|
|
||||||
|
# offline
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21
|
||||||
|
|
||||||
|
# online
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Create dummy weights:
|
||||||
|
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and place `config.json` and the tokenizer files under it.
|
||||||
|
# 2. Get `config.json` from ./config.md
|
||||||
|
# 3. Download the tokenizer
|
||||||
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
||||||
|
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||||
|
|
||||||
|
# Launch vllm
|
||||||
|
# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000
|
||||||
|
|
||||||
|
# offline
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21
|
||||||
|
|
||||||
|
# online
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34
|
||||||
|
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35
|
||||||
|
|
@ -0,0 +1,164 @@
|
||||||
|
# How to reproduce the benchmark results of SGLang
|
||||||
|
|
||||||
|
## Prerequisite
|
||||||
|
|
||||||
|
### Install the latest SGLang
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/sgl-project/sglang.git
|
||||||
|
cd sglang
|
||||||
|
git checkout v0.2.7
|
||||||
|
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install -e "python[all]"
|
||||||
|
|
||||||
|
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Set up ulimit and HF_TOKEN
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ulimit -n 65535
|
||||||
|
# Change the token to a real and usable one, with access permissions for the Llama 3 models.
|
||||||
|
export HF_TOKEN=hf_token
|
||||||
|
```
|
||||||
|
|
||||||
|
### Launch the server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Meta-Llama-3.1-8B-Instruct
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||||
|
|
||||||
|
# Meta-Llama-3.1-70B-Instruct
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8
|
||||||
|
|
||||||
|
# Meta-Llama-3-70B-Instruct-FP8
|
||||||
|
python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8
|
||||||
|
```
|
||||||
|
|
||||||
|
## Benchmark
|
||||||
|
|
||||||
|
### Hardware Requirements
|
||||||
|
|
||||||
|
- 8B models: Single NVIDIA A100 80GB GPU
|
||||||
|
- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8
|
||||||
|
- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8
|
||||||
|
|
||||||
|
Please ensure you have the appropriate hardware before running the benchmarks.
|
||||||
|
|
||||||
|
#### Offline benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl
|
||||||
|
cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Online benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl
|
||||||
|
cat online.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Other
|
||||||
|
|
||||||
|
We tried vLLM 0.5.3.post1, but it often crashes under high load, and our partial benchmarking showed similar or worse performance compared to vLLM 0.5.2, so we use the older version, vLLM 0.5.2.
|
||||||
|
|
||||||
|
Preparation for TensorRT LLM can refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# vLLM
|
||||||
|
pip install vllm==0.5.2
|
||||||
|
pip install jsonschema==4.21.1
|
||||||
|
|
||||||
|
# Meta-Llama-3-8B-Instruct
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests
|
||||||
|
|
||||||
|
# meta-llama/Meta-Llama-3-70B-Instruct
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8
|
||||||
|
|
||||||
|
# neuralmagic/Meta-Llama-3-70B-Instruct-FP8
|
||||||
|
python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# vLLM Offline
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl
|
||||||
|
cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# vLLM Online
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl
|
||||||
|
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl
|
||||||
|
cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# TensorRT LLM Offline 8B
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl
|
||||||
|
cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# TensorRT LLM Online 8B
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl
|
||||||
|
cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# TensorRT LLM Offline 70B
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl
|
||||||
|
cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# TensorRT LLM Online 70B
|
||||||
|
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl
|
||||||
|
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl
|
||||||
|
cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
### used for TensorRT LLM
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"architecture": "LlamaForCausalLM",
|
||||||
|
"dtype": "float16",
|
||||||
|
"logits_dtype": "float32",
|
||||||
|
"vocab_size": 128256,
|
||||||
|
"max_position_embeddings": 8192,
|
||||||
|
"hidden_size": 16384,
|
||||||
|
"num_hidden_layers": 126,
|
||||||
|
"num_attention_heads": 128,
|
||||||
|
"num_key_value_heads": 16,
|
||||||
|
"head_size": 128,
|
||||||
|
"qk_layernorm": false,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"intermediate_size": 53248,
|
||||||
|
"norm_epsilon": 1e-05,
|
||||||
|
"position_embedding_type": "rope_gpt_neox",
|
||||||
|
"use_parallel_embedding": false,
|
||||||
|
"embedding_sharding_dim": 0,
|
||||||
|
"share_embedding_table": false,
|
||||||
|
"mapping": {
|
||||||
|
"world_size": 8,
|
||||||
|
"tp_size": 8,
|
||||||
|
"pp_size": 1,
|
||||||
|
"gpus_per_node": 8
|
||||||
|
},
|
||||||
|
"quantization": {
|
||||||
|
"quant_algo": "FP8",
|
||||||
|
"kv_cache_quant_algo": null,
|
||||||
|
"group_size": 128,
|
||||||
|
"smoothquant_val": null,
|
||||||
|
"has_zero_point": false,
|
||||||
|
"pre_quant_scale": false,
|
||||||
|
"exclude_modules": [
|
||||||
|
"lm_head"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"kv_dtype": "float16",
|
||||||
|
"rotary_scaling": null,
|
||||||
|
"residual_mlp": false,
|
||||||
|
"moe_normalization_mode": null,
|
||||||
|
"rotary_base": 500000.0,
|
||||||
|
"moe_num_experts": 0,
|
||||||
|
"moe_top_k": 0,
|
||||||
|
"moe_tp_mode": 2,
|
||||||
|
"attn_bias": false,
|
||||||
|
"disable_weight_only_quant_plugin": false,
|
||||||
|
"mlp_bias": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### used for vLLM and SGLang
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"_name_or_path": "dummy_fp8",
|
||||||
|
"architectures": [
|
||||||
|
"LlamaForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": false,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 128000,
|
||||||
|
"eos_token_id": 128009,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 16384,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 53248,
|
||||||
|
"mlp_bias": false,
|
||||||
|
"model_type": "llama",
|
||||||
|
"num_attention_heads": 128,
|
||||||
|
"num_hidden_layers": 126,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"pretraining_tp": 1,
|
||||||
|
"quantization_config": {
|
||||||
|
"activation_scheme": "static",
|
||||||
|
"ignored_layers": [
|
||||||
|
"lm_head"
|
||||||
|
],
|
||||||
|
"quant_method": "fp8"
|
||||||
|
},
|
||||||
|
"rope_scaling": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"low_freq_factor": 1.0,
|
||||||
|
"high_freq_factor": 4.0,
|
||||||
|
"original_max_position_embeddings": 8192,
|
||||||
|
"rope_type": "llama3"
|
||||||
|
},
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 500000.0,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.41.1",
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 128256
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,285 @@
|
||||||
|
# DeepSeek V3 Support
|
||||||
|
|
||||||
|
The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).
|
||||||
|
|
||||||
|
Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.
|
||||||
|
|
||||||
|
For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek Model Optimizations in SGLang](https://docs.sglang.ai/references/deepseek.html).
|
||||||
|
|
||||||
|
## Installation & Launch
|
||||||
|
|
||||||
|
If you encounter errors when starting the server, make sure the weights have finished downloading. It's recommended to download them beforehand, or to restart the server multiple times until all weights are downloaded.
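
One way to pre-download the weights into the local Hugging Face cache (a sketch, assuming the `huggingface_hub` CLI is installed) is:

```bash
# Pre-download the checkpoint so the server does not have to fetch it at startup.
# huggingface-cli is provided by the huggingface_hub package (pip install -U huggingface_hub).
huggingface-cli download deepseek-ai/DeepSeek-V3
```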
|
||||||
|
|
||||||
|
### Using Docker (Recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Pull latest image
|
||||||
|
# https://hub.docker.com/r/lmsysorg/sglang/tags
|
||||||
|
docker pull lmsysorg/sglang:latest
|
||||||
|
|
||||||
|
# Launch
|
||||||
|
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host --network=host --privileged lmsysorg/sglang:latest \
|
||||||
|
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are using RDMA, please note that:
|
||||||
|
|
||||||
|
1. `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them.
|
||||||
|
2. You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.
|
||||||
|
|
||||||
|
Add [performance optimization options](#performance-optimization-options) as needed.
|
||||||
|
|
||||||
|
### Using pip
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Installation
|
||||||
|
pip install "sglang[all]>=0.4.3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
|
||||||
|
|
||||||
|
# Launch
|
||||||
|
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
Add [performance optimization options](#performance-optimization-options) as needed.
|
||||||
|
|
||||||
|
<a id="option_args"></a>
|
||||||
|
|
||||||
|
### Performance Optimization Options
|
||||||
|
|
||||||
|
[MLA optimizations](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) are enabled by default. Here are some optional optimizations that can be enabled as needed.
|
||||||
|
|
||||||
|
- [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models): For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
|
||||||
|
- [Torch.compile Optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#torchcompile-latency-optimizations): Add the `--enable-torch-compile` argument to enable it. Compilation adds some time while the server starts. The maximum batch size for torch.compile can be controlled with `--torch-compile-max-bs`; it's recommended to set it between `1` and `8` (e.g., `--torch-compile-max-bs 8`). An example launch command combining these options is shown below.
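
For example, a single-node launch that combines these options with the basic command above might look like:

```bash
# Example: basic launch command plus DP attention and torch.compile
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code \
  --enable-dp-attention --enable-torch-compile --torch-compile-max-bs 8
```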
|
||||||
|
|
||||||
|
### Example: Sending requests with OpenAI API
|
||||||
|
|
||||||
|
```python3
|
||||||
|
import openai
|
||||||
|
client = openai.Client(
|
||||||
|
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||||
|
|
||||||
|
# Chat completion
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=64,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
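
Equivalently, since the server exposes an OpenAI-compatible HTTP endpoint (port `30000` in the launch commands above), the same chat request can be sent with `curl`:

```bash
curl http://127.0.0.1:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "default",
    "messages": [
      {"role": "system", "content": "You are a helpful AI assistant"},
      {"role": "user", "content": "List 3 countries and their capitals."}
    ],
    "temperature": 0,
    "max_tokens": 64
  }'
```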
|
||||||
|
|
||||||
|
### Example: Serving with two H20\*8 nodes
|
||||||
|
|
||||||
|
Suppose there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1` and the second node's IP is `10.0.0.2`. Please **use the first node's IP** in both commands.
|
||||||
|
|
||||||
|
If the command fails, try setting the `GLOO_SOCKET_IFNAME` environment variable. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables).
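
A minimal sketch (the interface name `eth0` is only an example; pick the NIC that connects the two nodes):

```bash
# Hypothetical interface name -- replace eth0 with your actual network interface (see `ip addr`)
export GLOO_SOCKET_IFNAME=eth0
```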
|
||||||
|
|
||||||
|
If the nodes are connected with NVIDIA InfiniBand and you encounter hanging issues during startup, consider setting `export NCCL_IB_GID_INDEX=3`. For more information, see [this issue](https://github.com/sgl-project/sglang/issues/3516#issuecomment-2668493307).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# node 1
|
||||||
|
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
|
||||||
|
|
||||||
|
# node 2
|
||||||
|
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
If you have two H100 nodes, the usage is similar to the H20 example above.
|
||||||
|
|
||||||
|
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||||
|
|
||||||
|
### Example: Serving with two H200\*8 nodes and docker
|
||||||
|
|
||||||
|
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`.
|
||||||
|
A single H200 node with 8 GPUs can run DeepSeek V3; the dual-node setup here only demonstrates multi-node usage.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# node 1
|
||||||
|
docker run --gpus all \
|
||||||
|
--shm-size 32g \
|
||||||
|
--network=host \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--name sglang_multinode1 \
|
||||||
|
-it \
|
||||||
|
--rm \
|
||||||
|
--env "HF_TOKEN=$HF_TOKEN" \
|
||||||
|
--ipc=host \
|
||||||
|
lmsysorg/sglang:latest \
|
||||||
|
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# node 2
|
||||||
|
docker run --gpus all \
|
||||||
|
--shm-size 32g \
|
||||||
|
--network=host \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--name sglang_multinode2 \
|
||||||
|
-it \
|
||||||
|
--rm \
|
||||||
|
--env "HF_TOKEN=$HF_TOKEN" \
|
||||||
|
--ipc=host \
|
||||||
|
lmsysorg/sglang:latest \
|
||||||
|
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000
|
||||||
|
```
|
||||||
|
|
||||||
|
To ensure functionality, we include a test from a client Docker container.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --gpus all \
|
||||||
|
--shm-size 32g \
|
||||||
|
--network=host \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--name sglang_multinode_client \
|
||||||
|
-it \
|
||||||
|
--rm \
|
||||||
|
--env "HF_TOKEN=$HF_TOKEN" \
|
||||||
|
--ipc=host \
|
||||||
|
lmsysorg/sglang:latest \
|
||||||
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl"
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||||
|
|
||||||
|
### Example: Serving with four A100\*8 nodes
|
||||||
|
|
||||||
|
To serve DeepSeek-V3 with A100 GPUs, we first need to convert the [FP8 model checkpoints](https://huggingface.co/deepseek-ai/DeepSeek-V3) to BF16 with the [conversion script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) from the DeepSeek-V3 repository.
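
A sketch of the conversion step (the flag names follow the script's command-line options at the time of writing; run `python fp8_cast_bf16.py --help` to confirm):

```bash
# Convert the FP8 checkpoint to BF16; both paths are examples.
python fp8_cast_bf16.py \
    --input-fp8-hf-path /path/to/DeepSeek-V3 \
    --output-bf16-hf-path /path/to/DeepSeek-V3-BF16
```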
|
||||||
|
|
||||||
|
Since the BF16 model is over 1.3 TB, we need to prepare four A100 nodes, each with 8 x 80 GB GPUs. Assuming the first node's IP is `10.0.0.1` and the converted model path is `/path/to/DeepSeek-V3-BF16`, the server can be launched with the following commands.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# node 1
|
||||||
|
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 30000
|
||||||
|
|
||||||
|
# node 2
|
||||||
|
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 1 --trust-remote-code
|
||||||
|
|
||||||
|
# node 3
|
||||||
|
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 2 --trust-remote-code
|
||||||
|
|
||||||
|
# node 4
|
||||||
|
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 3 --trust-remote-code
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||||
|
|
||||||
|
Then we can benchmark the accuracy and latency by accessing the first node's exposed port with the following example commands.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# bench accuracy
|
||||||
|
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --host http://10.0.0.1 --port 30000
|
||||||
|
|
||||||
|
# bench latency
|
||||||
|
python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1:30000 --batch-size 1 --input-len 128 --output-len 128
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Example: Serving with 8 A100/A800 with AWQ Quantization
|
||||||
|
|
||||||
|
AWQ does not support BF16, so add the `--dtype half` flag if AWQ is used for quantization. One example is as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --dtype half
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Example: Serving with 16 A100/A800 with int8 Quantization
|
||||||
|
|
||||||
|
There are block-wise and per-channel quantization methods, and the quantized checkpoints have already been uploaded to Hugging Face:
|
||||||
|
|
||||||
|
- [meituan/DeepSeek-R1-Block-INT8](https://huggingface.co/meituan/DeepSeek-R1-Block-INT8)
|
||||||
|
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
|
||||||
|
|
||||||
|
Assuming the master node's IP is `MASTER_IP`, the checkpoint path is `/path/to/DeepSeek-R1-INT8`, and the port is `5000`, the server can be launched with the following commands:
|
||||||
|
```bash
|
||||||
|
#master
|
||||||
|
python3 -m sglang.launch_server \
|
||||||
|
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
|
||||||
|
MASTER_IP:5000 --nnodes 2 --node-rank 0 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
|
||||||
|
#cluster
|
||||||
|
python3 -m sglang.launch_server \
|
||||||
|
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
|
||||||
|
MASTER_IP:5000 --nnodes 2 --node-rank 1 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note that the launch command here enables `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||||
|
|
||||||
|
Then on the **master node**, supposing the ShareGPT data is located at `/path/to/ShareGPT_V3_unfiltered_cleaned_split.json`, you can run the following commands to benchmark the launched server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# bench accuracy
|
||||||
|
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319
|
||||||
|
|
||||||
|
# bench serving
|
||||||
|
python3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-input 128 --random-output 128 --num-prompts 1000 --request-rate 128 --random-range-ratio 1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.
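
For example, combining it with the accuracy command above:

```bash
# Faster accuracy benchmarking with 200 parallel requests
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 200
```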
|
||||||
|
|
||||||
|
### Example: Serving with 32 L40S with int8 Quantization
|
||||||
|
|
||||||
|
Run with the per-channel quantized model:
|
||||||
|
|
||||||
|
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
|
||||||
|
|
||||||
|
Assuming the master node's IP is `MASTER_IP`, the checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8`, and the port is `5000`, the server can be launched with the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#master
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote-code \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
#cluster
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote-code \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote-code \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||||
|
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote-code \
|
||||||
|
--enable-torch-compile --torch-compile-max-bs 32
|
||||||
|
```
|
||||||
|
|
||||||
|
The benchmarking method is the same as described in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.
|
||||||
|
|
||||||
|
### Example: Serving on any cloud or Kubernetes with SkyPilot
|
||||||
|
|
||||||
|
SkyPilot helps find the cheapest available GPUs across any cloud or existing Kubernetes cluster and launches distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
|
||||||
|
|
||||||
|
To serve on multiple nodes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/skypilot-org/skypilot.git
|
||||||
|
# Serve on 2 H100/H200x8 nodes
|
||||||
|
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B.yaml --retry-until-up
|
||||||
|
# Serve on 4 A100x8 nodes
|
||||||
|
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B-A100.yaml --retry-until-up
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Troubleshooting
|
||||||
|
|
||||||
|
If you encounter the following error with an fp16/bf16 checkpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ValueError: Weight output_partition_size = 576 is not divisible by weight quantization block_n = 128.
|
||||||
|
```
|
||||||
|
|
||||||
|
edit your `config.json` and remove the `quantization_config` block. For example:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"quantization_config": {
|
||||||
|
"activation_scheme": "dynamic",
|
||||||
|
"fmt": "e4m3",
|
||||||
|
"quant_method": "fp8",
|
||||||
|
"weight_block_size": [128, 128]
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
Removing this block typically resolves the error. For more details, see the discussion in [sgl-project/sglang#3491](https://github.com/sgl-project/sglang/issues/3491#issuecomment-2650779851).
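
A minimal sketch for doing this programmatically (the checkpoint path is an example; a backup of the original file is kept):

```bash
# Back up config.json, then strip the quantization_config block in place.
CONFIG=/path/to/DeepSeek-V3-BF16/config.json   # example path -- adjust to your checkpoint
cp "$CONFIG" "$CONFIG.bak"
python3 -c "
import json, sys
path = sys.argv[1]
with open(path) as f:
    cfg = json.load(f)
cfg.pop('quantization_config', None)
with open(path, 'w') as f:
    json.dump(cfg, f, indent=2)
" "$CONFIG"
```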
|
||||||
|
|
||||||
|
## DeepSeek V3 Optimization Plan
|
||||||
|
|
||||||
|
https://github.com/sgl-project/sglang/issues/2591
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```
|
||||||
|
pip3 install dspy-ai
|
||||||
|
```
|
||||||
|
|
||||||
|
Turn off the cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10:
|
||||||
|
```
|
||||||
|
cache_turn_on = False
|
||||||
|
```
|
||||||
|
|
||||||
|
or set the environment variable
|
||||||
|
|
||||||
|
```
|
||||||
|
export DSP_CACHEBOOL=false
|
||||||
|
```
|
||||||
|
|
||||||
|
## Benchmark SGLang
|
||||||
|
```
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_dspy_intro.py --backend sglang
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Benchmark TGI
|
||||||
|
```
|
||||||
|
docker run --name tgi --rm -ti --gpus all --network host \
|
||||||
|
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
|
||||||
|
ghcr.io/huggingface/text-generation-inference:1.3.0 \
|
||||||
|
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
|
||||||
|
--max-input-length 2048 --max-total-tokens 4096 \
|
||||||
|
--port 24000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_dspy_intro.py --backend tgi
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Benchmark vLLM
|
||||||
|
```
|
||||||
|
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_dspy_intro.py --backend vllm
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
"""
|
||||||
|
Adapted from
|
||||||
|
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
from dspy.datasets import HotPotQA
|
||||||
|
|
||||||
|
|
||||||
|
class BasicQA(dspy.Signature):
|
||||||
|
"""Answer questions with short factoid answers."""
|
||||||
|
|
||||||
|
question = dspy.InputField()
|
||||||
|
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateAnswer(dspy.Signature):
|
||||||
|
"""Answer questions with short factoid answers."""
|
||||||
|
|
||||||
|
context = dspy.InputField(desc="may contain relevant facts")
|
||||||
|
question = dspy.InputField()
|
||||||
|
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||||
|
|
||||||
|
|
||||||
|
class RAG(dspy.Module):
|
||||||
|
def __init__(self, num_passages=3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||||
|
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
|
||||||
|
|
||||||
|
def forward(self, question):
|
||||||
|
context = self.retrieve(question).passages
|
||||||
|
prediction = self.generate_answer(context=context, question=question)
|
||||||
|
return dspy.Prediction(context=context, answer=prediction.answer)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# lm = dspy.OpenAI(model='gpt-3.5-turbo')
|
||||||
|
if args.backend == "tgi":
|
||||||
|
lm = dspy.HFClientTGI(
|
||||||
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
|
elif args.backend == "sglang":
|
||||||
|
lm = dspy.HFClientSGLang(
|
||||||
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
|
elif args.backend == "vllm":
|
||||||
|
lm = dspy.HFClientVLLM(
|
||||||
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
|
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
|
||||||
|
url="http://20.102.90.50:2017/wiki17_abstracts"
|
||||||
|
)
|
||||||
|
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
dataset = HotPotQA(
|
||||||
|
train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
|
||||||
|
trainset = [x.with_inputs("question") for x in dataset.train]
|
||||||
|
devset = [x.with_inputs("question") for x in dataset.dev]
|
||||||
|
|
||||||
|
print(len(trainset), len(devset))
|
||||||
|
|
||||||
|
train_example = trainset[0]
|
||||||
|
print(f"Question: {train_example.question}")
|
||||||
|
print(f"Answer: {train_example.answer}")
|
||||||
|
|
||||||
|
dev_example = devset[18]
|
||||||
|
print(f"Question: {dev_example.question}")
|
||||||
|
print(f"Answer: {dev_example.answer}")
|
||||||
|
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define the predictor.
|
||||||
|
generate_answer = dspy.Predict(BasicQA)
|
||||||
|
|
||||||
|
# Call the predictor on a particular input.
|
||||||
|
pred = generate_answer(question=dev_example.question)
|
||||||
|
|
||||||
|
# Print the input and the prediction.
|
||||||
|
print(f"Question: {dev_example.question}")
|
||||||
|
print(f"Predicted Answer: {pred.answer}")
|
||||||
|
|
||||||
|
lm.inspect_history(n=1)
|
||||||
|
|
||||||
|
# Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
|
||||||
|
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
|
||||||
|
|
||||||
|
# Call the predictor on the same input.
|
||||||
|
pred = generate_answer_with_chain_of_thought(question=dev_example.question)
|
||||||
|
|
||||||
|
# Print the input, the chain of thought, and the prediction.
|
||||||
|
print(f"Question: {dev_example.question}")
|
||||||
|
print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
|
||||||
|
print(f"Predicted Answer: {pred.answer}")
|
||||||
|
|
||||||
|
retrieve = dspy.Retrieve(k=3)
|
||||||
|
topK_passages = retrieve(dev_example.question).passages
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
|
||||||
|
"-" * 30,
|
||||||
|
"\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, passage in enumerate(topK_passages):
|
||||||
|
print(f"{idx+1}]", passage, "\n")
|
||||||
|
|
||||||
|
retrieve("When was the first FIFA World Cup held?").passages[0]
|
||||||
|
|
||||||
|
from dspy.teleprompt import BootstrapFewShot
|
||||||
|
|
||||||
|
# Validation logic: check that the predicted answer is correct.
|
||||||
|
# Also check that the retrieved context does actually contain that answer.
|
||||||
|
def validate_context_and_answer(example, pred, trace=None):
|
||||||
|
answer_EM = dspy.evaluate.answer_exact_match(example, pred)
|
||||||
|
answer_PM = dspy.evaluate.answer_passage_match(example, pred)
|
||||||
|
return answer_EM and answer_PM
|
||||||
|
|
||||||
|
# Set up a basic teleprompter, which will compile our RAG program.
|
||||||
|
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
|
||||||
|
|
||||||
|
# Compile!
|
||||||
|
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
|
||||||
|
|
||||||
|
# Ask any question you like to this simple RAG program.
|
||||||
|
my_question = "What castle did David Gregory inherit?"
|
||||||
|
|
||||||
|
# Get the prediction. This contains `pred.context` and `pred.answer`.
|
||||||
|
pred = compiled_rag(my_question)
|
||||||
|
|
||||||
|
# Print the contexts and the answer.
|
||||||
|
print(f"Question: {my_question}")
|
||||||
|
print(f"Predicted Answer: {pred.answer}")
|
||||||
|
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
|
||||||
|
|
||||||
|
from dspy.evaluate.evaluate import Evaluate
|
||||||
|
|
||||||
|
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
|
||||||
|
evaluate_on_hotpotqa = Evaluate(
|
||||||
|
devset=devset,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
display_progress=True,
|
||||||
|
display_table=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
|
||||||
|
metric = dspy.evaluate.answer_exact_match
|
||||||
|
evaluate_on_hotpotqa(compiled_rag, metric=metric)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--port", type=int)
|
||||||
|
parser.add_argument("--num-threads", type=int, default=32)
|
||||||
|
parser.add_argument("--dev-size", type=int, default=150)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.port is None:
|
||||||
|
default_port = {
|
||||||
|
"vllm": 21000,
|
||||||
|
"lightllm": 22000,
|
||||||
|
"tgi": 24000,
|
||||||
|
"sglang": 30000,
|
||||||
|
}
|
||||||
|
args.port = default_port.get(args.backend, None)
|
||||||
|
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
## Download the dataset
|
||||||
|
|
||||||
|
```
|
||||||
|
wget -O agent_calls.jsonl "https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run benchmark
|
||||||
|
|
||||||
|
Ensure that this benchmark is run in a serial manner (using `--parallel 1`) to preserve any potential dependencies between requests.
|
||||||
|
|
||||||
|
### Benchmark sglang
|
||||||
|
```
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --num-events 1000 --parallel 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark vllm
|
||||||
|
```
|
||||||
|
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark guidance
|
||||||
|
```
|
||||||
|
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark lmql
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,300 @@
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
# here are the top five agent functions contributing ~70% LLM calls
|
||||||
|
# reference: https://github.com/joonspk-research/generative_agents/
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def poignancy_event(s, persona_name, persona_iss, event):
|
||||||
|
s += "Here is a brief description of " + persona_name + ".\n"
|
||||||
|
s += persona_iss + "\n"
|
||||||
|
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
||||||
|
s += persona_name + ".\n\n"
|
||||||
|
s += "Event: " + event
|
||||||
|
s += "Rate (return a number between 1 to 10):"
|
||||||
|
s += sgl.gen(name="Rate", max_tokens=2)
|
||||||
|
|
||||||
|
|
||||||
|
def poignancy_event_prompt(persona_name, persona_iss, event):
|
||||||
|
# return prompt and max_tokens
|
||||||
|
s = ""
|
||||||
|
s += "Here is a brief description of " + persona_name + ".\n"
|
||||||
|
s += persona_iss + "\n"
|
||||||
|
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
||||||
|
s += persona_name + ".\n\n"
|
||||||
|
s += "Event: " + event
|
||||||
|
s += "Rate (return a number between 1 to 10):"
|
||||||
|
return {"prompt": s, "max_tokens": 2, "stop": None}
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def generate_event_triple(s, persona_name, action):
|
||||||
|
s += """Task: Turn the input into (subject, predicate, object).
|
||||||
|
Input: Sam Johnson is eating breakfast.
|
||||||
|
Output: (Dolores Murphy, eat, breakfast)
|
||||||
|
---
|
||||||
|
Input: Joon Park is brewing coffee.
|
||||||
|
Output: (Joon Park, brew, coffee)
|
||||||
|
---
|
||||||
|
Input: Jane Cook is sleeping.
|
||||||
|
Output: (Jane Cook, is, sleep)
|
||||||
|
---
|
||||||
|
Input: Michael Bernstein is writing email on a computer.
|
||||||
|
Output: (Michael Bernstein, write, email)
|
||||||
|
---
|
||||||
|
Input: Percy Liang is teaching students in a classroom.
|
||||||
|
Output: (Percy Liang, teach, students)
|
||||||
|
---
|
||||||
|
Input: Merrie Morris is running on a treadmill.
|
||||||
|
Output: (Merrie Morris, run, treadmill)
|
||||||
|
---"""
|
||||||
|
s += persona_name + "is" + action + ".\n"
|
||||||
|
s += "(" + persona_name + ","
|
||||||
|
s += sgl.gen(name="Triple", max_tokens=20, stop=")")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_event_triple_prompt(persona_name, action):
|
||||||
|
s = ""
|
||||||
|
s += """Task: Turn the input into (subject, predicate, object).
|
||||||
|
Input: Sam Johnson is eating breakfast.
|
||||||
|
Output: (Dolores Murphy, eat, breakfast)
|
||||||
|
---
|
||||||
|
Input: Joon Park is brewing coffee.
|
||||||
|
Output: (Joon Park, brew, coffee)
|
||||||
|
---
|
||||||
|
Input: Jane Cook is sleeping.
|
||||||
|
Output: (Jane Cook, is, sleep)
|
||||||
|
---
|
||||||
|
Input: Michael Bernstein is writing email on a computer.
|
||||||
|
Output: (Michael Bernstein, write, email)
|
||||||
|
---
|
||||||
|
Input: Percy Liang is teaching students in a classroom.
|
||||||
|
Output: (Percy Liang, teach, students)
|
||||||
|
---
|
||||||
|
Input: Merrie Morris is running on a treadmill.
|
||||||
|
Output: (Merrie Morris, run, treadmill)
|
||||||
|
---"""
|
||||||
|
s += persona_name + "is" + action + ".\n"
|
||||||
|
s += "(" + persona_name + ","
|
||||||
|
return {"prompt": s, "max_tokens": 20, "stop": ")"}
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def generate_pronunciatio(s, action):
|
||||||
|
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
||||||
|
s += "Action description: " + action + ".\n"
|
||||||
|
s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pronunciatio_prompt(action):
|
||||||
|
s = ""
|
||||||
|
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
||||||
|
s += "Action description: " + action + ".\n"
|
||||||
|
s += "Emoji:"
|
||||||
|
return {"prompt": s, "max_tokens": 6, "stop": None}
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def action_location_sector(
|
||||||
|
s,
|
||||||
|
persona_name,
|
||||||
|
living_sector,
|
||||||
|
living_sector_areas,
|
||||||
|
current_sector,
|
||||||
|
current_sector_areas,
|
||||||
|
daily_plan,
|
||||||
|
sector_options,
|
||||||
|
current_action,
|
||||||
|
next_action,
|
||||||
|
):
|
||||||
|
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
||||||
|
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||||
|
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||||
|
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||||
|
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.
|
||||||
|
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
||||||
|
---
|
||||||
|
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
||||||
|
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
||||||
|
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||||
|
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.
|
||||||
|
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||||
|
---"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " lives in "
|
||||||
|
+ living_sector
|
||||||
|
+ " that has "
|
||||||
|
+ living_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is currently in "
|
||||||
|
+ current_sector
|
||||||
|
+ " that has "
|
||||||
|
+ current_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += daily_plan + ".\n"
|
||||||
|
s += "Area options: " + sector_options + ".\n"
|
||||||
|
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.\n"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ " should go to the following area: {"
|
||||||
|
)
|
||||||
|
s += sgl.gen(name="Location", max_tokens=10, stop="}")
|
||||||
|
|
||||||
|
|
||||||
|
def action_location_sector_prompt(
|
||||||
|
persona_name,
|
||||||
|
living_sector,
|
||||||
|
living_sector_areas,
|
||||||
|
current_sector,
|
||||||
|
current_sector_areas,
|
||||||
|
daily_plan,
|
||||||
|
sector_options,
|
||||||
|
current_action,
|
||||||
|
next_action,
|
||||||
|
):
|
||||||
|
s = ""
|
||||||
|
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
||||||
|
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||||
|
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||||
|
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||||
|
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.
|
||||||
|
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
||||||
|
---
|
||||||
|
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
||||||
|
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
||||||
|
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||||
|
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.
|
||||||
|
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||||
|
---"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " lives in "
|
||||||
|
+ living_sector
|
||||||
|
+ " that has "
|
||||||
|
+ living_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is currently in "
|
||||||
|
+ current_sector
|
||||||
|
+ " that has "
|
||||||
|
+ current_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += daily_plan + ".\n"
|
||||||
|
s += "Area options: " + sector_options + ".\n"
|
||||||
|
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
|
* Must be one of the "Area options," verbatim.\n"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ " should go to the following area: {"
|
||||||
|
)
|
||||||
|
return {"prompt": s, "max_tokens": 10, "stop": "}"}
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def action_location_object(
|
||||||
|
s, persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||||
|
):
|
||||||
|
s += """
|
||||||
|
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||||
|
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
||||||
|
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||||
|
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
||||||
|
Answer: {kitchen}
|
||||||
|
---
|
||||||
|
Tom Watson is in common room in Tom Watson's apartment.
|
||||||
|
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
||||||
|
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||||
|
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||||
|
Answer: {cafe}
|
||||||
|
---"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is going to "
|
||||||
|
+ target_sector
|
||||||
|
+ " that has the following areas: {"
|
||||||
|
+ target_sector_areas
|
||||||
|
+ "}\n"
|
||||||
|
)
|
||||||
|
s += """* Stay in the current area if the activity can be done there.
|
||||||
|
* NEVER go into other people's rooms unless necessary."""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ "should go to the following area in "
|
||||||
|
+ target_sector
|
||||||
|
)
|
||||||
|
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||||
|
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
|
||||||
|
|
||||||
|
|
||||||
|
def action_location_object_prompt(
|
||||||
|
persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||||
|
):
|
||||||
|
s = ""
|
||||||
|
s += """
|
||||||
|
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||||
|
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
||||||
|
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||||
|
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
||||||
|
Answer: {kitchen}
|
||||||
|
---
|
||||||
|
Tom Watson is in common room in Tom Watson's apartment.
|
||||||
|
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
||||||
|
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||||
|
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||||
|
Answer: {cafe}
|
||||||
|
---"""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is going to "
|
||||||
|
+ target_sector
|
||||||
|
+ " that has the following areas: {"
|
||||||
|
+ target_sector_areas
|
||||||
|
+ "}\n"
|
||||||
|
)
|
||||||
|
s += """* Stay in the current area if the activity can be done there.
|
||||||
|
* NEVER go into other people's rooms unless necessary."""
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ "should go to the following area in "
|
||||||
|
+ target_sector
|
||||||
|
)
|
||||||
|
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||||
|
s += "Answer: {"
|
||||||
|
return {"prompt": s, "max_tokens": 5, "stop": "}"}
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
from agent_functions import (
|
||||||
|
action_location_object_prompt,
|
||||||
|
action_location_sector_prompt,
|
||||||
|
generate_event_triple_prompt,
|
||||||
|
generate_pronunciatio_prompt,
|
||||||
|
poignancy_event_prompt,
|
||||||
|
)
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)[: args.num_events]
|
||||||
|
mapping = {
|
||||||
|
"poignancy_event": poignancy_event_prompt,
|
||||||
|
"generate_event_triple": generate_event_triple_prompt,
|
||||||
|
"generate_pronunciatio": generate_pronunciatio_prompt,
|
||||||
|
"action_location_sector": action_location_sector_prompt,
|
||||||
|
"action_location_object": action_location_object_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
|
||||||
|
states = []
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
call_generate = get_call_generate(args)
|
||||||
|
|
||||||
|
def get_one_answer(arg):
|
||||||
|
answer = call_generate(**arg, temperature=0)
|
||||||
|
states.append(answer)
|
||||||
|
|
||||||
|
async def get_one_answer_async(arg):
|
||||||
|
answer = await call_generate(**arg, temperature=0)
|
||||||
|
states.append(answer)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
# we always sequentially execute agent calls to maintain its dependency
|
||||||
|
if args.backend != "lmql":
|
||||||
|
for arg in tqdm(arguments):
|
||||||
|
get_one_answer(arg)
|
||||||
|
else:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
for arg in tqdm(arguments):
|
||||||
|
loop.run_until_complete(get_one_answer_async(arg))
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "Generative Agents",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
# to pack weighted functions as a single agent
|
||||||
|
"num_requests": len(arguments) / len(mapping),
|
||||||
|
"other": {
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
|
||||||
|
parser.add_argument("--num-events", type=int, default=10)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,74 @@
import argparse
import json
import time

from agent_functions import (
    action_location_object,
    action_location_sector,
    generate_event_triple,
    generate_pronunciatio,
    poignancy_event,
)

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_events]
    mapping = {
        "poignancy_event": poignancy_event,
        "generate_event_triple": generate_event_triple,
        "generate_pronunciatio": generate_pronunciatio,
        "action_location_sector": action_location_sector,
        "action_location_object": action_location_object,
    }
    arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    states = []

    # Run requests
    tic = time.time()
    for a in arguments:
        # only a single key in the dict
        for func, arg in a.items():
            result = func.run(**arg)
            result.sync()
            states.append(result)
    latency = time.time() - tic

    # Print results
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "Generative Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # to pack weighted functions as a single agent
            "num_requests": len(arguments) / len(mapping),
            "other": {
                "num_events": args.num_events,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
    parser.add_argument("--num-events", type=int, default=10)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
@ -0,0 +1,47 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 200
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 200 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 200 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```

```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
```
@ -0,0 +1,151 @@
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||||
|
if include_answer:
|
||||||
|
ret += " " + lines[i]["answer"]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_value(answer_str):
|
||||||
|
answer_str = answer_str.replace(",", "")
|
||||||
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
|
if len(numbers) < 1:
|
||||||
|
return INVALID
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(numbers[-1])
|
||||||
|
except SyntaxError:
|
||||||
|
return INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
call_generate = get_call_generate(args)
|
||||||
|
|
||||||
|
# Read data
|
||||||
|
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||||
|
filename = download_and_cache_file(url)
|
||||||
|
lines = list(read_jsonl(filename))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
|
assert all(l != INVALID for l in labels)
|
||||||
|
|
||||||
|
states = [None] * len(labels)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
if args.backend != "lmql":
|
||||||
|
# Use thread pool
|
||||||
|
def get_one_answer(i):
|
||||||
|
answer = call_generate(
|
||||||
|
prompt=few_shot_examples + questions[i],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop=["Question", "Assistant:", "<|separator|>"],
|
||||||
|
)
|
||||||
|
states[i] = answer
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(questions))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(questions)))),
|
||||||
|
total=len(questions),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Use asyncio
|
||||||
|
async def batched_call(batch_size):
|
||||||
|
for i in range(0, len(questions), batch_size):
|
||||||
|
tasks = []
|
||||||
|
for q in questions[i : i + batch_size]:
|
||||||
|
tasks.append(
|
||||||
|
call_generate(
|
||||||
|
few_shot_examples + q,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop="Question",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rets = await asyncio.gather(*tasks)
|
||||||
|
for j in range(len(rets)):
|
||||||
|
states[i + j] = rets[j]
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
preds = []
|
||||||
|
for i in range(len(states)):
|
||||||
|
preds.append(get_answer_value(states[i]))
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
invalid = np.mean(np.array(preds) == INVALID)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"Accuracy: {acc:.3f}")
|
||||||
|
print(f"Invalid: {invalid:.3f}")
|
||||||
|
print(f"Latency: {latency:.3f} s")
|
||||||
|
|
||||||
|
# Dump results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "gsm8k",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(acc, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-shots", type=int, default=5)
|
||||||
|
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=200)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sglang.api import set_default_backend
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||||
|
if include_answer:
|
||||||
|
ret += " " + lines[i]["answer"]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_value(answer_str):
|
||||||
|
answer_str = answer_str.replace(",", "")
|
||||||
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
|
if len(numbers) < 1:
|
||||||
|
return INVALID
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(numbers[-1])
|
||||||
|
except SyntaxError:
|
||||||
|
return INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
set_default_backend(select_sglang_backend(args))
|
||||||
|
|
||||||
|
# Read data
|
||||||
|
data_path = args.data_path
|
||||||
|
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||||
|
if not os.path.isfile(data_path):
|
||||||
|
data_path = download_and_cache_file(url)
|
||||||
|
lines = list(read_jsonl(data_path))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
|
assert all(l != INVALID for l in labels)
|
||||||
|
arguments = [{"question": q} for q in questions]
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
######### SGL Program Begin #########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def few_shot_gsm8k(s, question):
|
||||||
|
s += few_shot_examples + question
|
||||||
|
s += sgl.gen(
|
||||||
|
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
########## SGL Program End ##########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = few_shot_gsm8k.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
preds = []
|
||||||
|
for i in range(len(states)):
|
||||||
|
preds.append(get_answer_value(states[i]["answer"]))
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
invalid = np.mean(np.array(preds) == INVALID)
|
||||||
|
|
||||||
|
# Compute speed
|
||||||
|
num_output_tokens = sum(
|
||||||
|
s.get_meta_info("answer")["completion_tokens"] for s in states
|
||||||
|
)
|
||||||
|
output_throughput = num_output_tokens / latency
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"Accuracy: {acc:.3f}")
|
||||||
|
print(f"Invalid: {invalid:.3f}")
|
||||||
|
print(f"Latency: {latency:.3f} s")
|
||||||
|
print(f"Output throughput: {output_throughput:.3f} token/s")
|
||||||
|
|
||||||
|
# Dump results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "gsm8k",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(acc, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-shots", type=int, default=5)
|
||||||
|
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=200)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,47 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 200
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-questions 200 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```

```
python3 bench_other.py --num-questions 200 --backend lightllm
```

### Benchmark guidance
```
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

### Benchmark lmql
```
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```

```
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
```
@ -0,0 +1,118 @@
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
|
||||||
|
from sglang.utils import download_and_cache_file, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||||
|
if include_answer:
|
||||||
|
ret += lines[i]["endings"][lines[i]["label"]]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
call_select = get_call_select(args)
|
||||||
|
|
||||||
|
# Read data
|
||||||
|
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
||||||
|
filename = download_and_cache_file(url)
|
||||||
|
lines = list(read_jsonl(filename))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
choices = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
choices.append(lines[i]["endings"])
|
||||||
|
labels.append(lines[i]["label"])
|
||||||
|
|
||||||
|
preds = [None] * len(labels)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
if args.backend != "lmql":
|
||||||
|
# Use thread pool
|
||||||
|
def get_one_answer(i):
|
||||||
|
preds[i] = call_select(
|
||||||
|
context=few_shot_examples + questions[i], choices=choices[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(questions))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(questions)))),
|
||||||
|
total=len(questions),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use asyncio
|
||||||
|
async def batched_call(batch_size):
|
||||||
|
for i in range(0, len(questions), batch_size):
|
||||||
|
tasks = []
|
||||||
|
for q, c in zip(
|
||||||
|
questions[i : i + batch_size], choices[i : i + batch_size]
|
||||||
|
):
|
||||||
|
tasks.append(call_select(context=few_shot_examples + q, choices=c))
|
||||||
|
rets = await asyncio.gather(*tasks)
|
||||||
|
for j in range(len(rets)):
|
||||||
|
preds[i + j] = rets[j]
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
print(f"Accuracy: {acc:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "hellaswag",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(acc, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-shots", type=int, default=20)
|
||||||
|
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=200)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sglang.api import set_default_backend
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import download_and_cache_file, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||||
|
if include_answer:
|
||||||
|
ret += lines[i]["endings"][lines[i]["label"]]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
set_default_backend(select_sglang_backend(args))
|
||||||
|
|
||||||
|
# Read data
|
||||||
|
data_path = args.data_path
|
||||||
|
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
||||||
|
if not os.path.isfile(data_path):
|
||||||
|
data_path = download_and_cache_file(url)
|
||||||
|
lines = list(read_jsonl(data_path))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
choices = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
choices.append(lines[i]["endings"])
|
||||||
|
labels.append(lines[i]["label"])
|
||||||
|
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
######### SGL Program Begin #########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def few_shot_hellaswag(s, question, choices):
|
||||||
|
s += few_shot_examples + question
|
||||||
|
s += sgl.select("answer", choices=choices)
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
########## SGL Program End ##########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
rets = few_shot_hellaswag.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
print(f"Accuracy: {acc:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "hellaswag",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(acc, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-shots", type=int, default=20)
|
||||||
|
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=200)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,91 @@
## Run synthetic multi-turn benchmark

```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache

# SGLang server with radix cache on and first-come-first-served policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs

# The default SGLang server with radix cache on and longest-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000

# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
```

```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```

Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens there are to reuse, the larger the model, and the more constrained the GPU memory, the greater the benefit you can expect from hierarchical caching.

# Benchmark with more datasets

## Download Dataset
```bash
./download.sh {sharegpt|ultragpt|loogle|nextqa|all}
```
This script automatically downloads the required dataset to the current working directory.

## Multiturn Benchmark

### Supported Datasets
- sharegpt
- ultrachat
- loogle

### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
    --dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
    --port 8001 --enable-multiturn --disable-shuffle
```
This runs the `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as the backend on the
`longdep_qa.json` dataset. We send 10 conversations at 10 req/s to port 8001, with multi-turn
chat enabled and without shuffling the order of conversations (i.e. following the original
order in the dataset file).

### Note:
The requests of multiple conversations are sent in a round-robin fashion. For example, if we
have 3 conversations A, B, C with `[2, 3, 4]` rounds respectively, the multiturn benchmark sends
the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`.
This has implications for the cache reuse pattern: the cache reuse distance is the largest
under this request pattern, which means a prefix-aware local scheduler in the backend can
yield the most benefit compared to a FIFO scheduler. A minimal sketch of this interleaving
is shown below.
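The interleaving can be reproduced with a short, self-contained sketch. This is plain Python with made-up conversation counts, not code from the benchmark itself, and it simplifies one detail: in the real benchmark a conversation's next round is only enqueued after its previous response returns, so the exact order also depends on response timing.

```python
from collections import deque

# Hypothetical conversations A, B, C with 2, 3, and 4 rounds respectively.
conversations = {"A": 2, "B": 3, "C": 4}


def round_robin_order(conversations):
    """Yield request labels in the round-robin order described above."""
    pending = deque((name, 1, rounds) for name, rounds in conversations.items())
    while pending:
        name, next_round, total = pending.popleft()
        yield f"{name}{next_round}"
        if next_round < total:
            # Re-queue the conversation for its next round.
            pending.append((name, next_round + 1, total))


print(list(round_robin_order(conversations)))
# ['A1', 'B1', 'C1', 'A2', 'B2', 'C2', 'B3', 'C3', 'C4']
```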
## Shared Prefix Benchmark

### Supported Datasets
- loogle

### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
    --dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
    --port 8001 --enable-shared-prefix --disable-shuffle
```
### Note:
The shared prefix benchmark sends all questions for the same prompt together. For example,
if we have 3 shared prefixes A, B, C with `[2, 3, 4]` questions respectively, the benchmark
sends the requests to the backend in the following order:
`[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3, C+Q4]`. A minimal sketch of this grouping
follows.
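For contrast with the round-robin case above, here is a small sketch of the shared-prefix ordering. Again this is plain Python with hypothetical prefix and question counts, not the benchmark's own code:

```python
# Hypothetical shared prefixes with 2, 3, and 4 questions respectively.
prefixes = {"A": 2, "B": 3, "C": 4}


def shared_prefix_order(prefixes):
    """All questions of one prefix are sent back-to-back, prefix by prefix."""
    return [
        f"{name}+Q{q}"
        for name, num_questions in prefixes.items()
        for q in range(1, num_questions + 1)
    ]


print(shared_prefix_order(prefixes))
# ['A+Q1', 'A+Q2', 'B+Q1', 'B+Q2', 'B+Q3', 'C+Q1', 'C+Q2', 'C+Q3', 'C+Q4']
```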
## Multi Modality Benchmark (WIP)

### Supported Datasets:
- nextqa

### Example Usage:
```bash
# Server:
python3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B --tp 2 --dp 1 --port 8001 \
    --host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \
    --json-model-override-args "{\"architectures\": [\"LlavaVidForCausalLM\"], \"model_type\":\"llava\", \"mm_spatial_pool_stride\":2}"

# Client:
python3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang --dataset-path NExTVideo \
    --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \
    --enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048
```
Note: for the server args, `tokenizer-path` and the architecture override are required.

## Supported Backends
- sglang (oai)
- vllm (oai)
- lmdeploy (oai)
@ -0,0 +1,395 @@
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
|
||||||
|
from sglang.bench_serving import (
|
||||||
|
RequestFuncOutput,
|
||||||
|
get_tokenizer,
|
||||||
|
remove_prefix,
|
||||||
|
sample_random_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Script to benchmark concurrent requests to a server."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-clients",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Number of concurrent clients",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-parallel",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="Maximum number of parallel requests",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--request-length",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Length of each new request",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-length",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Length of each output",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-rounds",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Number of rounds per client",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distribution",
|
||||||
|
type=str,
|
||||||
|
default="poisson",
|
||||||
|
choices=["poisson", "uniform"],
|
||||||
|
help="Distribution type for request intervals (poisson or uniform)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--request-rate",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Average number of requests per second",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Server hostname or IP (default: localhost)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=30000,
|
||||||
|
help="Server port (default: 30000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
help="model path compatible with Hugging Face Transformers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-path",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="local dataset to sample tokens from",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-file",
|
||||||
|
type=str,
|
||||||
|
default="performance_metrics.jsonl",
|
||||||
|
help="File to log performance metrics",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_sglang_generate(
|
||||||
|
payload,
|
||||||
|
url,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sends a streaming request to the server. Gathers text token-by-token.
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
headers = {}
|
||||||
|
generated_text = ""
|
||||||
|
ttft = 0.0
|
||||||
|
st = time.perf_counter()
|
||||||
|
most_recent_timestamp = st
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(url=url, json=payload, headers=headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for chunk_bytes in response.content:
|
||||||
|
chunk_bytes = chunk_bytes.strip()
|
||||||
|
if not chunk_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||||
|
latency = time.perf_counter() - st
|
||||||
|
if chunk == "[DONE]":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
|
||||||
|
if data["text"]:
|
||||||
|
timestamp = time.perf_counter()
|
||||||
|
# First token
|
||||||
|
if ttft == 0.0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
|
||||||
|
# Decoding phase
|
||||||
|
else:
|
||||||
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
generated_text = data["text"]
|
||||||
|
|
||||||
|
output.generated_text = generated_text
|
||||||
|
output.success = True
|
||||||
|
output.latency = latency
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception as e:
|
||||||
|
output.success = False
|
||||||
|
output.error = str(e)
|
||||||
|
print(f"Request failed: {e}")
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def gen_payload(prompt, output_len):
|
||||||
|
payload = {
|
||||||
|
"text": prompt,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_new_tokens": output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
"stream": True,
|
||||||
|
"lora_path": "",
|
||||||
|
"return_logprob": False,
|
||||||
|
"logprob_start_len": -1,
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
|
||||||
|
"""Append the data with a timestamp to the specified JSONL file."""
|
||||||
|
timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
|
||||||
|
try:
|
||||||
|
with open(file_path, "a") as file:
|
||||||
|
file.write(
|
||||||
|
json.dumps(timestamped_data) + "\n"
|
||||||
|
) # Write as a single line in JSONL format
|
||||||
|
except IOError as e:
|
||||||
|
print(f"Error writing to JSONL file: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadyQueue:
|
||||||
|
"""
|
||||||
|
Thread-safe queue that can pop requests in different orders based on given policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, init_requests=None, policy="random"):
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.requests = init_requests or []
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
|
def append(self, item):
|
||||||
|
with self.lock:
|
||||||
|
self.requests.append(item)
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
with self.lock:
|
||||||
|
if not self.requests:
|
||||||
|
return None
|
||||||
|
if self.policy == "random":
|
||||||
|
index = random.randrange(len(self.requests))
|
||||||
|
return self.requests.pop(index)
|
||||||
|
elif self.policy == "fifo":
|
||||||
|
return self.requests.pop(0)
|
||||||
|
else:
|
||||||
|
# todo, varying thinking time of clients
|
||||||
|
raise ValueError(f"{self.policy} not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkloadGenerator:
|
||||||
|
def __init__(self, args):
|
||||||
|
# Construct the base URL for requests
|
||||||
|
self.url = f"http://{args.host}:{args.port}/generate"
|
||||||
|
|
||||||
|
self.tokenizer = get_tokenizer(args.model_path)
|
||||||
|
self.distribution = args.distribution
|
||||||
|
self.request_rate = args.request_rate
|
||||||
|
self.start_time = None
|
||||||
|
self.finished_time = None
|
||||||
|
|
||||||
|
self.sent_requests = 0
|
||||||
|
self.completed_requests = 0
|
||||||
|
|
||||||
|
self.candidate_inputs = sample_random_requests(
|
||||||
|
input_len=args.request_length,
|
||||||
|
output_len=args.output_length,
|
||||||
|
num_prompts=args.num_clients * args.num_rounds,
|
||||||
|
range_ratio=1.0,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
)
|
||||||
|
self.candidate_inputs = [i[0] for i in self.candidate_inputs]
|
||||||
|
|
||||||
|
init_requests = [
|
||||||
|
(i, gen_payload(self.candidate_inputs[i], args.output_length))
|
||||||
|
for i in range(args.num_clients)
|
||||||
|
]
|
||||||
|
self.client_records = {
|
||||||
|
i: {"round": 0, "history": init_requests[i][1]["text"]}
|
||||||
|
for i in range(args.num_clients)
|
||||||
|
}
|
||||||
|
self.ready_queue = ReadyQueue(init_requests=init_requests)
|
||||||
|
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
||||||
|
|
||||||
|
self.response_queue = queue.Queue()
|
||||||
|
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
|
||||||
|
self.performance_metrics = {"ttft": [], "latency": []}
|
||||||
|
|
||||||
|
async def handle_request(self, item):
|
||||||
|
try:
|
||||||
|
client_id, payload = item
|
||||||
|
response = await async_request_sglang_generate(payload, self.url, self.pbar)
|
||||||
|
if self.pbar.n == self.pbar.total:
|
||||||
|
self.finished_time = time.time()
|
||||||
|
self.response_queue.put((client_id, response))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Request failed: {e}")
|
||||||
|
|
||||||
|
def request_sender(self):
|
||||||
|
async def request_loop():
|
||||||
|
while True:
|
||||||
|
if self.sent_requests - self.completed_requests < args.max_parallel:
|
||||||
|
new_request = self.ready_queue.pop()
|
||||||
|
if new_request:
|
||||||
|
asyncio.create_task(self.handle_request(new_request))
|
||||||
|
self.sent_requests += 1
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.pbar.n == self.pbar.total:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Calculate Poisson-distributed wait time
|
||||||
|
if self.distribution == "poisson":
|
||||||
|
sleep_time = random.expovariate(self.request_rate)
|
||||||
|
elif self.distribution == "uniform":
|
||||||
|
avg_interval = (
|
||||||
|
1.0 / self.request_rate if self.request_rate > 0 else 1.0
|
||||||
|
)
|
||||||
|
sleep_time = random.uniform(0, 2 * avg_interval)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid distribution type")
|
||||||
|
await asyncio.sleep(sleep_time) # Wait before sending the next request
|
||||||
|
|
||||||
|
# Create and run the event loop for asynchronous requests
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(request_loop())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
def response_handler(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
client_id, response = self.response_queue.get(
|
||||||
|
timeout=10
|
||||||
|
) # Block until response is available
|
||||||
|
if not response.success:
|
||||||
|
raise ValueError(f"Request failed with error: {response.error}")
|
||||||
|
self.client_records[client_id]["history"] += response.generated_text
|
||||||
|
self.client_records[client_id]["round"] += 1
|
||||||
|
self.performance_metrics["ttft"].append(response.ttft)
|
||||||
|
self.performance_metrics["latency"].append(response.latency)
|
||||||
|
self.completed_requests += 1
|
||||||
|
|
||||||
|
if self.client_records[client_id]["round"] < args.num_rounds:
|
||||||
|
self.client_records[client_id][
|
||||||
|
"history"
|
||||||
|
] += self.candidate_inputs.pop()
|
||||||
|
self.ready_queue.append(
|
||||||
|
(
|
||||||
|
client_id,
|
||||||
|
gen_payload(
|
||||||
|
self.client_records[client_id]["history"],
|
||||||
|
args.output_length,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except queue.Empty:
|
||||||
|
if self.pbar.n == self.pbar.total:
|
||||||
|
break
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||||
|
response_thread = threading.Thread(target=self.response_handler, daemon=True)
|
||||||
|
|
||||||
|
self.start_time = time.time()
|
||||||
|
request_thread.start()
|
||||||
|
response_thread.start()
|
||||||
|
|
||||||
|
request_thread.join()
|
||||||
|
response_thread.join()
|
||||||
|
self.pbar.close()
|
||||||
|
|
||||||
|
performance_data = {
|
||||||
|
"summary": {
|
||||||
|
"total_requests": len(self.performance_metrics["ttft"]),
|
||||||
|
"request_rate": self.request_rate,
|
||||||
|
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||||
|
/ len(self.performance_metrics["ttft"]),
|
||||||
|
"p90_ttft": sorted(self.performance_metrics["ttft"])[
|
||||||
|
int(0.9 * len(self.performance_metrics["ttft"]))
|
||||||
|
],
|
||||||
|
"median_ttft": sorted(self.performance_metrics["ttft"])[
|
||||||
|
len(self.performance_metrics["ttft"]) // 2
|
||||||
|
],
|
||||||
|
"average_latency": sum(self.performance_metrics["latency"])
|
||||||
|
/ len(self.performance_metrics["latency"]),
|
||||||
|
"p90_latency": sorted(self.performance_metrics["latency"])[
|
||||||
|
int(0.9 * len(self.performance_metrics["latency"]))
|
||||||
|
],
|
||||||
|
"median_latency": sorted(self.performance_metrics["latency"])[
|
||||||
|
len(self.performance_metrics["latency"]) // 2
|
||||||
|
],
|
||||||
|
"throughput": self.pbar.total / (self.finished_time - self.start_time),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
print("All requests completed")
|
||||||
|
print("Performance metrics summary:")
|
||||||
|
print(
|
||||||
|
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
|
||||||
|
)
|
||||||
|
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||||
|
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||||
|
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||||
|
print(
|
||||||
|
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||||
|
)
|
||||||
|
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||||
|
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||||
|
print(
|
||||||
|
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||||
|
)
|
||||||
|
log_to_jsonl_file(performance_data, args.log_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||||
|
|
||||||
|
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
|
||||||
|
args.request_rate = request_rate
|
||||||
|
requests.post(flush_cache_url)
|
||||||
|
time.sleep(1)
|
||||||
|
WorkloadGenerator(args).run()
|
||||||
File diff suppressed because it is too large
|
|
@ -0,0 +1,589 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from nextqa import NExTQALoader
|
||||||
|
|
||||||
|
# from nextqa.video import , VideoPrompt
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||||
|
|
||||||
|
from sglang.bench_serving import (
|
||||||
|
download_and_cache_file,
|
||||||
|
gen_prompt,
|
||||||
|
get_gen_prefix_cache_path,
|
||||||
|
)
|
||||||
|
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||||
|
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
|
||||||
|
from sglang.utils import encode_video_base64
|
||||||
|
|
||||||
|
# type of content fields, can be only prompts or with images/videos
|
||||||
|
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
|
||||||
|
|
||||||
|
# A list of all the conversations. Each conversation is a list of
|
||||||
|
# tuples. If multiturn is not enabled, the length of list is 1,
|
||||||
|
# containing only the first Q&A pair.
|
||||||
|
# For the shared prefix workload (synthetic, loogle, nextqa), it
|
||||||
|
# is a list of conversations sharing the same prefix (synthetic,
|
||||||
|
# doc, video)
|
||||||
|
SampleOutput = List[List[Tuple[MsgContent, int, int]]]
|
||||||
|
|
||||||
|
|
||||||
|
def common_filter_chat(
|
||||||
|
num_requests: int,
|
||||||
|
new_dataset: List,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
min_prompt_len: Optional[int],
|
||||||
|
min_output_len: Optional[int],
|
||||||
|
max_prompt_len: Optional[int],
|
||||||
|
max_output_len: Optional[int],
|
||||||
|
fixed_output_len: Optional[int],
|
||||||
|
) -> SampleOutput:
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: SampleOutput = []
|
||||||
|
l = 0
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
while l < num_requests:
|
||||||
|
for i in range(len(new_dataset)):
|
||||||
|
if l == num_requests:
|
||||||
|
break
|
||||||
|
processed = []
|
||||||
|
for j in new_dataset[i]:
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = j[0]
|
||||||
|
prompt_token_ids = tokenizer.encode(prompt)
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
|
||||||
|
completion = j[1]
|
||||||
|
completion_token_ids = tokenizer.encode(completion)
|
||||||
|
output_len = (
|
||||||
|
len(completion_token_ids)
|
||||||
|
if fixed_output_len is None
|
||||||
|
else fixed_output_len
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
min_prompt_len is not None
|
||||||
|
and prompt_len < min_prompt_len
|
||||||
|
or min_output_len is not None
|
||||||
|
and output_len < min_output_len
|
||||||
|
or max_prompt_len is not None
|
||||||
|
and prompt_len > max_prompt_len
|
||||||
|
or max_output_len is not None
|
||||||
|
and output_len > max_output_len
|
||||||
|
):
|
||||||
|
# Prune too short sequences.
|
||||||
|
continue
|
||||||
|
input_tokens += prompt_len
|
||||||
|
output_tokens += output_len
|
||||||
|
processed.append((prompt, prompt_len, output_len))
|
||||||
|
filtered_dataset.append(processed)
|
||||||
|
l += 1
|
||||||
|
|
||||||
|
print(f"#Input tokens: {input_tokens}")
|
||||||
|
print(f"#Output tokens: {output_tokens}")
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def sample_sharegpt_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
enable_multiturn: bool = True,
|
||||||
|
fixed_output_len: Optional[int] = None,
|
||||||
|
) -> SampleOutput:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Download sharegpt if necessary
|
||||||
|
if not os.path.isfile(dataset_path):
|
||||||
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
|
|
||||||
|
# Keep one conversation in one list
|
||||||
|
new_dataset = []
|
||||||
|
for data in dataset:
|
||||||
|
if len(data["conversations"]) % 2 != 0:
|
||||||
|
continue
|
||||||
|
if data["conversations"][0]["from"] != "human":
|
||||||
|
continue
|
||||||
|
chat = []
|
||||||
|
total_len = 2
|
||||||
|
if enable_multiturn:
|
||||||
|
total_len = len(data["conversations"])
|
||||||
|
for i in range(0, total_len, 2):
|
||||||
|
# One user One Assistant
|
||||||
|
chat.append(
|
||||||
|
(
|
||||||
|
data["conversations"][i]["value"],
|
||||||
|
data["conversations"][i + 1]["value"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_dataset.append(chat)
|
||||||
|
|
||||||
|
if not disable_shuffle:
|
||||||
|
# Shuffle the dataset.
|
||||||
|
random.shuffle(new_dataset)
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: SampleOutput = common_filter_chat(
|
||||||
|
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
|
||||||
|
)
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def sample_ultrachat_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
enable_multiturn: bool = True,
|
||||||
|
fixed_output_len: Optional[int] = None,
|
||||||
|
) -> SampleOutput:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Load the dataset
|
||||||
|
dataset = []
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
while True:
|
||||||
|
line = f.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
dataset.append(json.loads(line))
|
||||||
|
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["data"]) >= 2]
|
||||||
|
|
||||||
|
# Keep one conversation in one list
|
||||||
|
new_dataset = []
|
||||||
|
for data in dataset:
|
||||||
|
if len(data["data"]) % 2 != 0:
|
||||||
|
continue
|
||||||
|
chat = []
|
||||||
|
total_len = 2
|
||||||
|
if enable_multiturn:
|
||||||
|
total_len = len(data["data"])
|
||||||
|
for i in range(0, total_len, 2):
|
||||||
|
# One user One Assistant
|
||||||
|
chat.append((data["data"][i], data["data"][i + 1]))
|
||||||
|
new_dataset.append(chat)
|
||||||
|
|
||||||
|
# Shuffle the dataset.
|
||||||
|
if not disable_shuffle:
|
||||||
|
random.shuffle(new_dataset)
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: SampleOutput = common_filter_chat(
|
||||||
|
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
|
||||||
|
)
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def sample_loogle_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
enable_multiturn: bool = True,
|
||||||
|
enable_shared_prefix: bool = False,
|
||||||
|
fixed_output_len: Optional[int] = None,
|
||||||
|
) -> SampleOutput:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Load the dataset
|
||||||
|
dataset = []
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
while True:
|
||||||
|
line = f.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
dataset.append(json.loads(line))
|
||||||
|
|
||||||
|
# Keep one conversation in one list
|
||||||
|
new_dataset = []
|
||||||
|
# TODO: Add shared prefix support for loogle
|
||||||
|
# NOTE: Now we preprocess it only for chat
|
||||||
|
for data in dataset:
|
||||||
|
chat = []
|
||||||
|
if (
|
||||||
|
"qa_pairs" not in data
|
||||||
|
or data["qa_pairs"] == "none"
|
||||||
|
or len(data["qa_pairs"]) == 0
|
||||||
|
):
|
||||||
|
# If Q is none (for summarization),
|
||||||
|
# We add a question for summarization
|
||||||
|
# And keep the summary up to 1024 words
|
||||||
|
chat.append(
|
||||||
|
(
|
||||||
|
"Input: "
|
||||||
|
+ data["input"]
|
||||||
|
+ " Question: "
|
||||||
|
+ "Please summarize the input",
|
||||||
|
data["input"][:1024],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_dataset.append(chat)
|
||||||
|
else:
|
||||||
|
qa_pairs = eval(data["qa_pairs"])
|
||||||
|
for i, qa in enumerate(qa_pairs):
|
||||||
|
if i == 0 or enable_shared_prefix:
|
||||||
|
# Combine input with the first Q
|
||||||
|
chat.append(
|
||||||
|
("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
|
||||||
|
)
|
||||||
|
elif enable_multiturn:
|
||||||
|
chat.append((qa["Q"], qa["A"]))
|
||||||
|
|
||||||
|
new_dataset.append(chat)
|
||||||
|
|
||||||
|
# Shuffle the dataset.
|
||||||
|
if not disable_shuffle:
|
||||||
|
random.shuffle(new_dataset)
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: SampleOutput = common_filter_chat(
|
||||||
|
num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
|
||||||
|
)
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def sample_nextqa_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
max_frames: int, # Specific for video
|
||||||
|
model_path: str,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
enable_multiturn: bool = True, # No multiturn support for now
|
||||||
|
backend: str = "sglang-oai",
|
||||||
|
chat_template_name: Optional[str] = None,
|
||||||
|
fixed_output_len: Optional[int] = None,
|
||||||
|
) -> SampleOutput:
|
||||||
|
"""
|
||||||
|
Example of messages:
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": base64_data}},
|
||||||
|
{"type": "text", "text": video.prompt},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
if fixed_output_len is None:
|
||||||
|
fixed_output_len = 4096
|
||||||
|
|
||||||
|
# TODO: Check for multiturn
|
||||||
|
dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
|
||||||
|
new_dataset = []
|
||||||
|
for v in dataset:
|
||||||
|
new_dataset.append(v)
|
||||||
|
|
||||||
|
if not disable_shuffle:
|
||||||
|
random.shuffle(new_dataset)
|
||||||
|
|
||||||
|
# TODO: prompt len can get from server side
|
||||||
|
filtered_dataset = []
|
||||||
|
l = 0
|
||||||
|
while l < num_requests:
|
||||||
|
for i in range(len(new_dataset)):
|
||||||
|
if l == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
video = new_dataset[i]
|
||||||
|
|
||||||
|
# text prompt
|
||||||
|
prompt = video.prompt
|
||||||
|
|
||||||
|
# NOTE: Chat Template is a must for video benchmark because we have to
|
||||||
|
# add special image token for later expansion
|
||||||
|
if backend == "sglang" or backend == "sglang-native":
|
||||||
|
if "chat_template" in tokenizer.init_kwargs:
|
||||||
|
chat_template = get_chat_template(tokenizer.get_chat_template())
|
||||||
|
elif chat_template_name is not None:
|
||||||
|
chat_template = get_chat_template(chat_template_name)
|
||||||
|
else:
|
||||||
|
chat_template = get_chat_template_by_model_path(model_path)
|
||||||
|
prompt = chat_template.image_token + prompt
|
||||||
|
|
||||||
|
prompt_token_ids = tokenizer(prompt).input_ids
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = fixed_output_len # max output len, not real output len
|
||||||
|
|
||||||
|
# video input
|
||||||
|
base64_data = encode_video_base64(video.path, video.num_frames)
|
||||||
|
|
||||||
|
# NOTE: This will be replaced by the expanded length from the server
|
||||||
|
prompt_len += video.num_frames
|
||||||
|
|
||||||
|
# add to content
|
||||||
|
content = [
|
||||||
|
{"type": "image_url", "image_url": {"url": base64_data}},
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered_dataset.append([(content, prompt_len, output_len)])
|
||||||
|
l += 1
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def sample_random_requests(
|
||||||
|
input_len: int,
|
||||||
|
output_len: int,
|
||||||
|
num_prompts: int,
|
||||||
|
range_ratio: float,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
dataset_path: str,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
) -> SampleOutput:
|
||||||
|
|
||||||
|
input_lens = np.random.randint(
|
||||||
|
max(int(input_len * range_ratio), 1),
|
||||||
|
input_len + 1,
|
||||||
|
size=num_prompts,
|
||||||
|
)
|
||||||
|
output_lens = np.random.randint(
|
||||||
|
int(output_len * range_ratio),
|
||||||
|
output_len + 1,
|
||||||
|
size=num_prompts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if True:
|
||||||
|
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
|
||||||
|
|
||||||
|
# Download sharegpt if necessary
|
||||||
|
if not os.path.isfile(dataset_path):
|
||||||
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
|
# Only keep the first two turns of each conversation.
|
||||||
|
dataset = [
|
||||||
|
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||||
|
for data in dataset
|
||||||
|
]
|
||||||
|
|
||||||
|
if not disable_shuffle:
|
||||||
|
# Shuffle the dataset.
|
||||||
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
input_requests: SampleOutput = []
|
||||||
|
for data in dataset:
|
||||||
|
i = len(input_requests)
|
||||||
|
if i == num_prompts:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = data[0]
|
||||||
|
prompt_token_ids = tokenizer.encode(prompt)
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
|
||||||
|
# Skip empty prompt
|
||||||
|
if prompt_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if prompt_len > input_lens[i]:
|
||||||
|
input_ids = prompt_token_ids[: input_lens[i]]
|
||||||
|
else:
|
||||||
|
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
|
||||||
|
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
|
||||||
|
prompt = tokenizer.decode(input_ids)
|
||||||
|
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
|
||||||
|
else:
|
||||||
|
# Sample token ids from random integers. This can cause some NaN issues.
|
||||||
|
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||||
|
input_requests = []
|
||||||
|
for i in range(num_prompts):
|
||||||
|
prompt = tokenizer.decode(
|
||||||
|
[
|
||||||
|
(offsets[i] + i + j) % tokenizer.vocab_size
|
||||||
|
for j in range(input_lens[i])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
|
||||||
|
|
||||||
|
print(f"#Input tokens: {np.sum(input_lens)}")
|
||||||
|
print(f"#Output tokens: {np.sum(output_lens)}")
|
||||||
|
return input_requests
|
||||||
|
|
||||||
|
|
||||||
|
def gen_prompt(tokenizer, token_num):
|
||||||
|
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
|
||||||
|
all_available_tokens = list(tokenizer.get_vocab().values())
|
||||||
|
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
||||||
|
return tokenizer.decode(selected_tokens)
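# Illustrative usage (the model name is only an example):
#   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#   text = gen_prompt(tok, 128)  # decodes ~128 randomly chosen vocab tokens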
|
||||||
|
|
||||||
|
|
||||||
|
def get_gen_prefix_cache_path(args, tokenizer):
|
||||||
|
"""Create cache directory under ~/.cache/sglang/benchmark"""
|
||||||
|
cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
|
||||||
|
|
||||||
|
# Create a unique cache filename based on the generation parameters
|
||||||
|
cache_key = (
|
||||||
|
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
|
||||||
|
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
|
||||||
|
f"{tokenizer.__class__.__name__}.pkl"
|
||||||
|
)
|
||||||
|
return cache_dir / cache_key
|
||||||
|
|
||||||
|
|
||||||
|
def sample_generated_shared_prefix_requests(
|
||||||
|
num_groups: int,
|
||||||
|
prompts_per_group: int,
|
||||||
|
system_prompt_len: int,
|
||||||
|
question_len: int,
|
||||||
|
output_len: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
args,
|
||||||
|
disable_shuffle: bool = False,
|
||||||
|
) -> SampleOutput:
|
||||||
|
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
|
||||||
|
cache_path = get_gen_prefix_cache_path(args, tokenizer)
|
||||||
|
|
||||||
|
# Try to load from cache first
|
||||||
|
if cache_path.exists():
|
||||||
|
print(f"\nLoading cached generated input data from {cache_path}")
|
||||||
|
with open(cache_path, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
print("\nGenerating new input data...")
|
||||||
|
|
||||||
|
# Generate system prompts for each group
|
||||||
|
system_prompts = []
|
||||||
|
for _ in range(num_groups):
|
||||||
|
system_prompt = gen_prompt(tokenizer, system_prompt_len)
|
||||||
|
system_prompts.append(system_prompt)
|
||||||
|
|
||||||
|
# Generate questions
|
||||||
|
questions = []
|
||||||
|
for _ in range(num_groups * prompts_per_group):
|
||||||
|
question = gen_prompt(tokenizer, question_len)
|
||||||
|
questions.append(question)
|
||||||
|
|
||||||
|
# Combine system prompts with questions
|
||||||
|
input_requests = []
|
||||||
|
total_input_tokens = 0
|
||||||
|
total_output_tokens = 0
|
||||||
|
|
||||||
|
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
|
||||||
|
system_prompt = system_prompts[group_idx]
|
||||||
|
input_requests.append([])
|
||||||
|
for prompt_idx in tqdm(
|
||||||
|
range(prompts_per_group), desc="Generating questions", leave=False
|
||||||
|
):
|
||||||
|
question = questions[group_idx * prompts_per_group + prompt_idx]
|
||||||
|
full_prompt = f"{system_prompt}\n\n{question}"
|
||||||
|
prompt_len = len(tokenizer.encode(full_prompt))
|
||||||
|
input_requests[-1].append((full_prompt, prompt_len, output_len))
|
||||||
|
total_input_tokens += prompt_len
|
||||||
|
total_output_tokens += output_len
|
||||||
|
|
||||||
|
if not disable_shuffle:
|
||||||
|
# Shuffle questions
|
||||||
|
random.shuffle(input_requests)
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
print(f"\nGenerated shared prefix dataset statistics:")
|
||||||
|
print(f"Number of groups: {num_groups}")
|
||||||
|
print(f"Prompts per group: {prompts_per_group}")
|
||||||
|
print(f"Total prompts: {len(input_requests) * prompts_per_group}")
|
||||||
|
print(f"Total input tokens: {total_input_tokens}")
|
||||||
|
print(f"Total output tokens: {total_output_tokens}")
|
||||||
|
print(
|
||||||
|
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"Caching generated input data to {cache_path}")
|
||||||
|
with open(cache_path, "wb") as f:
|
||||||
|
pickle.dump(input_requests, f)
|
||||||
|
|
||||||
|
return input_requests
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(args, tokenizer):
|
||||||
|
if args.dataset_name == "sharegpt":
|
||||||
|
input_requests = sample_sharegpt_requests(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
|
enable_multiturn=args.enable_multiturn,
|
||||||
|
fixed_output_len=args.fixed_output_len,
|
||||||
|
)
|
||||||
|
elif args.dataset_name == "ultrachat":
|
||||||
|
input_requests = sample_ultrachat_requests(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
|
enable_multiturn=args.enable_multiturn,
|
||||||
|
fixed_output_len=args.fixed_output_len,
|
||||||
|
)
|
||||||
|
elif args.dataset_name == "loogle":
|
||||||
|
input_requests = sample_loogle_requests(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
|
enable_multiturn=args.enable_multiturn,
|
||||||
|
enable_shared_prefix=args.enable_shared_prefix,
|
||||||
|
fixed_output_len=args.fixed_output_len,
|
||||||
|
)
|
||||||
|
elif args.dataset_name == "nextqa":
|
||||||
|
input_requests = sample_nextqa_requests(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_frames=args.max_frames,
|
||||||
|
model_path=args.model,
|
||||||
|
disable_shuffle=args.disable_shuffle,
|
||||||
|
enable_multiturn=args.enable_multiturn,
|
||||||
|
backend=args.backend,
|
||||||
|
chat_template_name=args.chat_template,
|
||||||
|
fixed_output_len=args.fixed_output_len,
|
||||||
|
)
|
||||||
|
elif args.dataset_name == "random":
|
||||||
|
input_requests = sample_random_requests(
|
||||||
|
input_len=args.random_input_len,
|
||||||
|
output_len=args.random_output_len,
|
||||||
|
num_prompts=args.num_prompts,
|
||||||
|
range_ratio=args.random_range_ratio,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
)
|
||||||
|
elif args.dataset_name == "generated-shared-prefix":
|
||||||
|
input_requests = sample_generated_shared_prefix_requests(
|
||||||
|
num_groups=args.gen_num_groups,
|
||||||
|
prompts_per_group=args.gen_prompts_per_group,
|
||||||
|
system_prompt_len=args.gen_system_prompt_len,
|
||||||
|
question_len=args.gen_question_len,
|
||||||
|
output_len=args.gen_output_len,
|
||||||
|
args=args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||||
|
return input_requests
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
#!/usr/bin/bash
|
||||||
|
|
||||||
|
# The usage function
|
||||||
|
usage() {
|
||||||
|
echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}"
|
||||||
|
exit 1
|
||||||
|
}
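# Example: "$0 sharegpt" downloads only the ShareGPT dataset; "$0 all" downloads every dataset.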
|
||||||
|
|
||||||
|
# The download function
|
||||||
|
download() {
|
||||||
|
case "$1" in
|
||||||
|
sharegpt)
|
||||||
|
echo $1
|
||||||
|
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||||
|
;;
|
||||||
|
ultragpt)
|
||||||
|
echo $1
|
||||||
|
# Questions about the world
|
||||||
|
wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json
|
||||||
|
# Writing and Creation
|
||||||
|
wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json
|
||||||
|
wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json
|
||||||
|
# External materials
|
||||||
|
wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz
|
||||||
|
gunzip ultrachat_existent_material_release_230420.json.gz
|
||||||
|
;;
|
||||||
|
loogle)
|
||||||
|
echo $1
|
||||||
|
git lfs install
|
||||||
|
git clone git@hf.co:datasets/bigainlco/LooGLE
|
||||||
|
unzip LooGLE/data.zip
|
||||||
|
;;
|
||||||
|
nextqa)
|
||||||
|
echo $1
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/datasets/lmms-lab/NExTQA
|
||||||
|
unzip NExTQA/videos.zip
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
usage
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
}
|
||||||
|
|
||||||
|
# Arg check
|
||||||
|
if [ "$#" -ne 1 ]; then
|
||||||
|
usage
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Invoke
|
||||||
|
|
||||||
|
case "$1" in
|
||||||
|
sharegpt|ultragpt|loogle|nextqa)
|
||||||
|
download "$1"
|
||||||
|
;;
|
||||||
|
all)
|
||||||
|
download sharegpt
|
||||||
|
download ultragpt
|
||||||
|
download loogle
|
||||||
|
download nextqa
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
usage
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import av
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def find_video_files(video_dir) -> List[str]:
|
||||||
|
if os.path.isfile(video_dir):
|
||||||
|
return [video_dir]
|
||||||
|
|
||||||
|
video_files = []
|
||||||
|
for root, dirs, files in os.walk(video_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith((".mp4", ".avi", ".mov")):
|
||||||
|
video_files.append(os.path.join(root, file))
|
||||||
|
# if the entry is a directory (joined with root so the path is valid), recurse into it
|
||||||
|
elif os.path.isdir(os.path.join(root, file)):
|
||||||
|
video_files.extend(find_video_files(os.path.join(root, file)))
|
||||||
|
return video_files
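# Example (paths are illustrative): find_video_files("./videos")
# -> ["./videos/a.mp4", "./videos/clips/b.mov", ...]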
|
||||||
|
|
||||||
|
|
||||||
|
def video_frames(video_path, max_frames) -> int:
|
||||||
|
container = av.open(video_path)
|
||||||
|
total_frames = container.streams.video[0].frames
|
||||||
|
return min(total_frames, max_frames)
|
||||||
|
|
||||||
|
|
||||||
|
class Video:
|
||||||
|
def __init__(self, video_path, num_frames):
|
||||||
|
self.path = video_path
|
||||||
|
self.num_frames = num_frames
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Video({self.path}, {self.num_frames})"
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter((self.path, self.num_frames))
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPrompt(Video):
|
||||||
|
def __init__(self, video_path, num_frames, prompt):
|
||||||
|
super().__init__(video_path, num_frames)
|
||||||
|
self.prompt = prompt
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter((self.path, self.num_frames, self.prompt))
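# Defining __iter__ lets callers unpack a VideoPrompt directly, e.g.:
#   video_path, num_frames, prompt = video_prompt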
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLoader:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class VideoFileLoader(VideoLoader):
|
||||||
|
"""
|
||||||
|
Load all the videos in a directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
|
||||||
|
super().__init__()
|
||||||
|
self.video_dir = video_dir
|
||||||
|
self.video_files = find_video_files(video_dir)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.max_frames = max_frames
|
||||||
|
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
|
||||||
|
|
||||||
|
def __iter__(self): # (file, number of frames)
|
||||||
|
if self.batch_size == 1:
|
||||||
|
for video_file in self.video_files:
|
||||||
|
yield Video(video_file, video_frames(video_file, self.max_frames))
|
||||||
|
else:
|
||||||
|
batch = []
|
||||||
|
for video_file in self.video_files:
|
||||||
|
video = Video(video_file, video_frames(video_file, self.max_frames))
|
||||||
|
batch.append(video)
|
||||||
|
if len(batch) == self.batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
|
||||||
|
|
||||||
|
class NExTQALoader(VideoLoader):
|
||||||
|
"""
|
||||||
|
Load videos and prompts from the NExTQA dataset
|
||||||
|
set: train, test or validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
task: 'MC' or 'OE'
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.task = task
|
||||||
|
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
|
||||||
|
self.ds = load_dataset("lmms-lab/NExTQA", task)
|
||||||
|
self.ds = self.ds[dset]
|
||||||
|
|
||||||
|
# self.n = ds.num_rows
|
||||||
|
self.video_dir = video_dir
|
||||||
|
self.video_files = find_video_files(video_dir)
|
||||||
|
self.video_to_path = dict()
|
||||||
|
for video_file in self.video_files:
|
||||||
|
video_id = video_file.split("/")[-1].split(".")[0]
|
||||||
|
self.video_to_path[video_id] = video_file
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.max_frames = max_frames
|
||||||
|
|
||||||
|
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
|
||||||
|
# Get video
|
||||||
|
video_id = entry["video"]
|
||||||
|
video_path = self.video_to_path[video_id]
|
||||||
|
assert os.path.exists(video_path), f"Video not found: {video_path}"
|
||||||
|
num_frames = min(entry["frame_count"], max_frames)
|
||||||
|
video = Video(video_path, num_frames)
|
||||||
|
prompt = entry["question"] + "?"
|
||||||
|
if self.task == "MC": # add choices
|
||||||
|
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
|
||||||
|
return VideoPrompt(video_path, num_frames, prompt)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.batch_size == 1:
|
||||||
|
for entry in self.ds:
|
||||||
|
yield self.get_video_prompt(entry, self.max_frames)
|
||||||
|
else:
|
||||||
|
batch = []
|
||||||
|
for entry in self.ds:
|
||||||
|
video = self.get_video_prompt(entry, self.max_frames)
|
||||||
|
batch.append(video)
|
||||||
|
if len(batch) == self.batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
|
||||||
|
|
||||||
|
# main
|
||||||
|
if __name__ == "__main__":
|
||||||
|
video_dir = "./videos"
|
||||||
|
# video_loader = VideoFileLoader(video_dir, batch_size=16)
|
||||||
|
# for batch in video_loader:
|
||||||
|
# print(f"Number of videos in batch: {len(batch)}")
|
||||||
|
# for video_file, num_frames in batch:
|
||||||
|
# print(f"Video: {video_file} number of frames: {num_frames}")
|
||||||
|
|
||||||
|
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
|
||||||
|
for batch in video_loader:
|
||||||
|
print(f"Number of videos in batch: {len(batch)}")
|
||||||
|
for video_file, num_frames, prompt in batch:
|
||||||
|
print(
|
||||||
|
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
|
||||||
|
)
|
||||||
|
# break
|
||||||
|
# for video_file, prompt in batch:
|
||||||
|
# print(f"Video: {video_file} prompt: {prompt}")
|
||||||
|
# break
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
## Run benchmark
|
||||||
|
|
||||||
|
### Build dataset
|
||||||
|
```
|
||||||
|
pip install wikipedia
|
||||||
|
python3 build_dataset.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
```
|
||||||
|
llama_cpp_python 0.2.19
|
||||||
|
guidance 0.1.10
|
||||||
|
vllm 0.2.5
|
||||||
|
outlines 0.0.22
|
||||||
|
```
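For reference, one way to install these pinned versions with pip (assuming the packages come from PyPI; adjust to your environment):

```
pip install llama_cpp_python==0.2.19 guidance==0.1.10 vllm==0.2.5 outlines==0.0.22
```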
|
||||||
|
|
||||||
|
### Benchmark sglang
|
||||||
|
|
||||||
|
Run Llama-7B
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Run Mixtral-8x7B
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --num-questions 10
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark Outlines + vLLM
|
||||||
|
|
||||||
|
Run Llama-7B
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --backend outlines --num-questions 10
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark guidance
|
||||||
|
|
||||||
|
Run Llama-7B and benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
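# REGEX_LIST is meant to match a JSON-style list of double-quoted strings,
# e.g. ["a", "b", "c"] (assuming REGEX_STR matches a single quoted string).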
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def json_decode(document, generate):
|
||||||
|
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": '
|
||||||
|
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "country": '
|
||||||
|
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "latitude": '
|
||||||
|
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||||
|
s += ' "population": '
|
||||||
|
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||||
|
s += ' "top 3 landmarks": '
|
||||||
|
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||||
|
s += "}\n"
|
||||||
|
|
||||||
|
return s
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)
|
||||||
|
arguments = []
|
||||||
|
for i in range(len(lines[: args.num_questions])):
|
||||||
|
arguments.append(
|
||||||
|
{
|
||||||
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
call_generate = partial(get_call_generate(args), temperature=0)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
def get_one_answer(i):
|
||||||
|
states[i] = json_decode(generate=call_generate, **arguments[i])
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(arguments))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
rets = list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(arguments)))),
|
||||||
|
total=len(arguments),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in rets:
|
||||||
|
pass
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "json_decode_regex",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=20)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
|
||||||
|
def json_warm_up(s):
|
||||||
|
s += "The information about Hogwarts is in the following JSON format.\n"
|
||||||
|
with s.var_scope("json_output"):
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||||
|
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||||
|
s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||||
|
s += "}\n"
|
||||||
|
print(f'The warm-up json result is:\n{s["json_output"]}')
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
|
||||||
|
def json_decode(s, document):
|
||||||
|
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
with s.var_scope("json_output"):
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||||
|
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||||
|
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||||
|
s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||||
|
s += "}\n"
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)
|
||||||
|
lines = list(lines)
|
||||||
|
arguments = []
|
||||||
|
for i in range(len(lines[: args.num_questions])):
|
||||||
|
arguments.append(
|
||||||
|
{
|
||||||
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
json_warm_up.run().sync()
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = json_decode.run_batch(
|
||||||
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
|
||||||
|
for state in states:
|
||||||
|
fout.write(state["json_output"] + "\n")
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "json_decode_regex",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=20)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
import wikipedia
|
||||||
|
|
||||||
|
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
|
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||||
|
city_names = [
|
||||||
|
"los angeles",
|
||||||
|
"london",
|
||||||
|
"tokyo",
|
||||||
|
"beijing",
|
||||||
|
"singapore",
|
||||||
|
"paris",
|
||||||
|
"dubai",
|
||||||
|
"sydney",
|
||||||
|
"moscow",
|
||||||
|
"rome",
|
||||||
|
"toronto",
|
||||||
|
"rio de janeiro",
|
||||||
|
"istanbul",
|
||||||
|
"berlin",
|
||||||
|
"auckland",
|
||||||
|
"buenos aires",
|
||||||
|
"mexico city",
|
||||||
|
"mumbai",
|
||||||
|
"seoul",
|
||||||
|
"bangkok",
|
||||||
|
"cairo",
|
||||||
|
"athens",
|
||||||
|
"jerusalem",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_content(city_name):
|
||||||
|
content = str(wikipedia.page(city_name).content)
|
||||||
|
content = content.replace("\n\n", "\n")
|
||||||
|
|
||||||
|
tokens = t.encode(content)
|
||||||
|
|
||||||
|
expected_tokens = 3000
|
||||||
|
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
||||||
|
truncate_content = content[:truncate_len]
|
||||||
|
truncate_tokens = t.encode(truncate_content)
|
||||||
|
|
||||||
|
# Count token
|
||||||
|
print(
|
||||||
|
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return truncate_content
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("questions.jsonl", "w") as fout:
|
||||||
|
for city_name in city_names:
|
||||||
|
truncate_content = get_content(city_name)
|
||||||
|
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
## Run benchmark
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
```
|
||||||
|
llama_cpp_python 0.2.38
|
||||||
|
guidance 0.1.10
|
||||||
|
vllm 0.2.7
|
||||||
|
outlines 0.0.25
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build dataset
|
||||||
|
|
||||||
|
When benchmarking long document information retrieval, run the following command to build the dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install wikipedia
|
||||||
|
python3 build_dataset.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark sglang
|
||||||
|
|
||||||
|
Run Llama-7B
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark Character Generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py --mode character
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark City Information Retrieval
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py --mode city
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark Outlines + vLLM
|
||||||
|
|
||||||
|
Run Llama-7B
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark Character Generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode character --backend outlines
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark City Information Retrieval
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode city --backend outlines
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark guidance
|
||||||
|
|
||||||
|
Run Llama-7B and benchmark character generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
Run Llama-7B and benchmark city information retrieval
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark lmql
|
||||||
|
|
||||||
|
Run Llama-7B and benchmark character generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode character --backend lmql --parallel 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Run Llama-7B and benchmark city information retrieval
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_other.py --mode city --backend lmql --parallel 1
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,288 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import guidance
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
# There are some FSM bugs with the JSON regex converted from the pydantic model,
|
||||||
|
# so use a string regex here instead.
|
||||||
|
# regex_string = build_regex_from_object(HarryPoterRole)
|
||||||
|
character_regex = (
|
||||||
|
r"""\{\n"""
|
||||||
|
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
||||||
|
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
||||||
|
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
||||||
|
+ r""" "wand": \{\n"""
|
||||||
|
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
||||||
|
+ r""" \},\n"""
|
||||||
|
+ r""" "alive": "(Alive|Deceased)",\n"""
|
||||||
|
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
||||||
|
+ r"""\}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
city_regex = (
|
||||||
|
r"""\{\n"""
|
||||||
|
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
||||||
|
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
||||||
|
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
||||||
|
+ r"""\}"""
|
||||||
|
)
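# city_regex fixes the JSON skeleton (keys, punctuation, newlines) and only lets the
# model fill in the value fields: a 1-16 char name/country, a short latitude,
# an integer population, and exactly three landmark strings.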
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def character_gen(name, generate):
|
||||||
|
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
||||||
|
s += generate(s, max_tokens=256, regex=character_regex)
|
||||||
|
return s
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def city_gen(document, generate):
|
||||||
|
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
s += generate(s, max_tokens=256, regex=city_regex)
|
||||||
|
return s
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
@guidance
|
||||||
|
def character_maker(lm, name):
|
||||||
|
regex_str_no_quote = r"[\w\d\s]+"
|
||||||
|
regex_float = r"[0-9]+\.[0-9]+"
|
||||||
|
lm += f"""\
|
||||||
|
{name} is a character in Harry Potter. Please fill in the following information about this character.
|
||||||
|
{{
|
||||||
|
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
|
||||||
|
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
|
||||||
|
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
|
||||||
|
"wand": {{
|
||||||
|
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
|
||||||
|
}},
|
||||||
|
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
|
||||||
|
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
return lm
|
||||||
|
|
||||||
|
|
||||||
|
async def call_generate_lmql(
|
||||||
|
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
|
||||||
|
):
|
||||||
|
assert model is not None
|
||||||
|
import lmql
|
||||||
|
|
||||||
|
@lmql.query(model=model)
|
||||||
|
async def program(question, max_tokens, regex):
|
||||||
|
'''lmql
|
||||||
|
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
|
||||||
|
return ANSWER
|
||||||
|
'''
|
||||||
|
|
||||||
|
return await program(
|
||||||
|
question=prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
max_len=max_len,
|
||||||
|
regex=regex,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@guidance
|
||||||
|
def city_maker(lm, document):
|
||||||
|
regex_str_no_quote = r"[\w\d\s]+"
|
||||||
|
regex_float = r"[0-9]+\.[0-9]+"
|
||||||
|
lm += f"""\
|
||||||
|
Please extract the information of a city from the following wikipedia page.
|
||||||
|
Page begin.
|
||||||
|
{document}
|
||||||
|
Page end.
|
||||||
|
Here is the name, country, and symbol of the city in JSON format.
|
||||||
|
{{
|
||||||
|
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
|
||||||
|
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
|
||||||
|
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
|
||||||
|
"top 3 landmarks": [
|
||||||
|
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
return lm
|
||||||
|
|
||||||
|
|
||||||
|
def bench_character(args):
|
||||||
|
arguments = []
|
||||||
|
with open(args.data_path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
arguments.append({"name": line.strip()})
|
||||||
|
arguments = arguments[: args.num_jsons]
|
||||||
|
|
||||||
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
if args.backend == "outlines":
|
||||||
|
call_generate = partial(get_call_generate(args), temperature=0)
|
||||||
|
|
||||||
|
def get_one_answer(i):
|
||||||
|
states[i] = character_gen(**arguments[i], generate=call_generate)
|
||||||
|
|
||||||
|
elif args.backend == "guidance":
|
||||||
|
model = guidance.models.LlamaCpp(
|
||||||
|
args.model_path,
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=args.n_ctx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_one_answer(i):
|
||||||
|
lm = model + character_maker(**arguments[i])
|
||||||
|
states[i] = lm
|
||||||
|
|
||||||
|
elif args.backend == "lmql":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import lmql
|
||||||
|
|
||||||
|
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
||||||
|
call_generate = partial(
|
||||||
|
call_generate_lmql,
|
||||||
|
model=model,
|
||||||
|
max_tokens=256,
|
||||||
|
regex=character_regex,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_one_answer_async(i):
|
||||||
|
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
if args.backend != "lmql":
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(arguments))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
rets = list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(arguments)))),
|
||||||
|
total=len(arguments),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in rets:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
batches = []
|
||||||
|
for i in range(0, len(arguments), args.parallel):
|
||||||
|
batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
for bt in tqdm(batches):
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
||||||
|
)
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
return states, latency
|
||||||
|
|
||||||
|
|
||||||
|
def bench_city_doc(args):
|
||||||
|
arguments = []
|
||||||
|
for line in read_jsonl(args.data_path):
|
||||||
|
arguments.append({"document": line["document"]})
|
||||||
|
arguments = arguments[: args.num_jsons]
|
||||||
|
|
||||||
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
if args.backend == "outlines":
|
||||||
|
call_generate = partial(get_call_generate(args), temperature=0)
|
||||||
|
|
||||||
|
def get_one_answer(i):
|
||||||
|
states[i] = city_gen(**arguments[i], generate=call_generate)
|
||||||
|
|
||||||
|
elif args.backend == "guidance":
|
||||||
|
model = guidance.models.LlamaCpp(
|
||||||
|
args.model_path,
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=args.n_ctx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_one_answer(i):
|
||||||
|
lm = model + city_maker(**arguments[i])
|
||||||
|
states[i] = lm
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(arguments))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
rets = executor.map(get_one_answer, list(range(len(arguments))))
|
||||||
|
for _ in rets:
|
||||||
|
pass
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
return states, latency
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.mode == "character":
|
||||||
|
args.data_path = "dataset.txt"
|
||||||
|
states, latency = bench_character(args)
|
||||||
|
elif args.mode == "city":
|
||||||
|
args.data_path = "questions.jsonl"
|
||||||
|
states, latency = bench_city_doc(args)
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "json_jump_forward",
|
||||||
|
"backend": args.backend,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_jsons": args.num_jsons,
|
||||||
|
"mode": args.mode,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str)
|
||||||
|
parser.add_argument("--num-jsons", type=int, default=50)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode", type=str, default="character", choices=["character", "city"]
|
||||||
|
)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
# There are some FSM bugs with the JSON regex converted from the pydantic model,
|
||||||
|
# so use a string regex here instead.
|
||||||
|
# regex_string = build_regex_from_object(HarryPoterRole)
|
||||||
|
character_regex = (
|
||||||
|
r"""\{\n"""
|
||||||
|
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
||||||
|
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
||||||
|
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
||||||
|
+ r""" "wand": \{\n"""
|
||||||
|
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
||||||
|
+ r""" \},\n"""
|
||||||
|
+ r""" "alive": "(Alive|Deceased)",\n"""
|
||||||
|
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
||||||
|
+ r"""\}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
city_regex = (
|
||||||
|
r"""\{\n"""
|
||||||
|
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
||||||
|
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
||||||
|
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
||||||
|
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
||||||
|
+ r"""\}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
|
||||||
|
def character_gen(s, name):
|
||||||
|
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
||||||
|
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
|
||||||
|
def city_gen(s, document):
|
||||||
|
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
s += sgl.gen("json_output", max_tokens=256, regex=city_regex)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def bench_city_doc(args):
|
||||||
|
arguments = []
|
||||||
|
for line in read_jsonl(args.data_path):
|
||||||
|
arguments.append({"document": line["document"]})
|
||||||
|
arguments = arguments[: args.num_jsons]
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = city_gen.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
return states, latency
|
||||||
|
|
||||||
|
|
||||||
|
def bench_character(args):
|
||||||
|
arguments = []
|
||||||
|
with open(args.data_path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
arguments.append({"name": line.strip()})
|
||||||
|
arguments = arguments[: args.num_jsons]
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = character_gen.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
return states, latency
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.mode == "character":
|
||||||
|
args.data_path = "dataset.txt"
|
||||||
|
states, latency = bench_character(args)
|
||||||
|
elif args.mode == "city":
|
||||||
|
args.data_path = "questions.jsonl"
|
||||||
|
states, latency = bench_city_doc(args)
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
||||||
|
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
|
||||||
|
for state in states:
|
||||||
|
fout.write(state["json_output"] + "\n")
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "json_jump_forward",
|
||||||
|
"backend": args.backend,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_jsons": args.num_jsons,
|
||||||
|
"mode": args.mode,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str)
|
||||||
|
parser.add_argument("--num-jsons", type=int, default=50)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode", type=str, default="character", choices=["character", "city"]
|
||||||
|
)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
import wikipedia
|
||||||
|
|
||||||
|
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
|
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||||
|
city_names = [
|
||||||
|
"los angeles",
|
||||||
|
"london",
|
||||||
|
"tokyo",
|
||||||
|
"beijing",
|
||||||
|
"singapore",
|
||||||
|
"paris",
|
||||||
|
"dubai",
|
||||||
|
"sydney",
|
||||||
|
"moscow",
|
||||||
|
"rome",
|
||||||
|
"toronto",
|
||||||
|
"rio de janeiro",
|
||||||
|
"istanbul",
|
||||||
|
"berlin",
|
||||||
|
"auckland",
|
||||||
|
"buenos aires",
|
||||||
|
"mexico city",
|
||||||
|
"mumbai",
|
||||||
|
"seoul",
|
||||||
|
"bangkok",
|
||||||
|
"cairo",
|
||||||
|
"athens",
|
||||||
|
"jerusalem",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_content(city_name):
|
||||||
|
content = str(wikipedia.page(city_name).content)
|
||||||
|
content = content.replace("\n\n", "\n")
|
||||||
|
|
||||||
|
tokens = t.encode(content)
|
||||||
|
|
||||||
|
expected_tokens = 3000
|
||||||
|
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
||||||
|
truncate_content = content[:truncate_len]
|
||||||
|
truncate_tokens = t.encode(truncate_content)
|
||||||
|
|
||||||
|
# Count token
|
||||||
|
print(
|
||||||
|
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return truncate_content
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("questions.jsonl", "w") as fout:
|
||||||
|
for city_name in city_names:
|
||||||
|
truncate_content = get_content(city_name)
|
||||||
|
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
Harry Potter
|
||||||
|
Hermione Granger
|
||||||
|
Ron Weasley
|
||||||
|
Albus Dumbledore
|
||||||
|
Severus Snape
|
||||||
|
Rubeus Hagrid
|
||||||
|
Draco Malfoy
|
||||||
|
Ginny Weasley
|
||||||
|
Fred Weasley
|
||||||
|
George Weasley
|
||||||
|
Percy Weasley
|
||||||
|
Sirius Black
|
||||||
|
Remus Lupin
|
||||||
|
Neville Longbottom
|
||||||
|
Luna Lovegood
|
||||||
|
Cedric Diggory
|
||||||
|
Cho Chang
|
||||||
|
Lord Voldemort
|
||||||
|
Minerva McGonagall
|
||||||
|
Filius Flitwick
|
||||||
|
Dolores Umbridge
|
||||||
|
Bellatrix Lestrange
|
||||||
|
Lucius Malfoy
|
||||||
|
Molly Weasley
|
||||||
|
Arthur Weasley
|
||||||
|
Nymphadora Tonks
|
||||||
|
Dobby
|
||||||
|
Moaning Myrtle
|
||||||
|
Peter Pettigrew
|
||||||
|
Alastor 'Mad-Eye' Moody
|
||||||
|
Horace Slughorn
|
||||||
|
Vernon Dursley
|
||||||
|
Petunia Dursley
|
||||||
|
Dudley Dursley
|
||||||
|
Argus Filch
|
||||||
|
Sybill Trelawney
|
||||||
|
Gilderoy Lockhart
|
||||||
|
Fleur Delacour
|
||||||
|
Viktor Krum
|
||||||
|
Bill Weasley
|
||||||
|
Oliver Wood
|
||||||
|
Cornelius Fudge
|
||||||
|
Barty Crouch Sr.
|
||||||
|
Barty Crouch Jr.
|
||||||
|
Kingsley Shacklebolt
|
||||||
|
Quirinus Quirrell
|
||||||
|
Nearly Headless Nick
|
||||||
|
Aunt Marge
|
||||||
|
Griphook
|
||||||
|
Ludo Bagman
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
## Run benchmark
|
||||||
|
|
||||||
|
### Benchmark sglang
|
||||||
|
|
||||||
|
Run Llama-8b
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,146 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def schema_gen(s, message: Tuple[str, str], json_schema: str):
|
||||||
|
system, user = message
|
||||||
|
s += sgl.system(system)
|
||||||
|
s += sgl.user(user)
|
||||||
|
s += sgl.assistant(
|
||||||
|
sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def contains_formats(schema, formats: List[str]):
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
if schema.get("format", None) in formats:
|
||||||
|
return True
|
||||||
|
for value in schema.values():
|
||||||
|
if contains_formats(value, formats):
|
||||||
|
return True
|
||||||
|
elif isinstance(schema, list):
|
||||||
|
for item in schema:
|
||||||
|
if contains_formats(item, formats):
|
||||||
|
return True
|
||||||
|
return False
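# Examples:
#   contains_formats({"type": "string", "format": "email"}, ["email"])  -> True
#   contains_formats({"type": "integer"}, ["email"])                    -> False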
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset(path: str):
|
||||||
|
raw_dataset = load_dataset(path)
|
||||||
|
dataset = []
|
||||||
|
for data in raw_dataset["train"]:
|
||||||
|
messages = data["prompt"]
|
||||||
|
schema = data["schema"]
|
||||||
|
obj = json.loads(schema)
|
||||||
|
|
||||||
|
# skip some corrupted examples
|
||||||
|
if obj.get("type", None) is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip schema with format "email"
|
||||||
|
# which is not supported by outlines for now
|
||||||
|
if contains_formats(obj, ["email"]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
system = messages[0]
|
||||||
|
user = messages[1]
|
||||||
|
assert system["role"] == "system", "invalid role"
|
||||||
|
assert user["role"] == "user", "invalid role"
|
||||||
|
assert len(messages) == 2, "invalid message length"
|
||||||
|
message = json.dumps(system["content"]), json.dumps(user["content"])
|
||||||
|
dataset.append(
|
||||||
|
{
|
||||||
|
"message": message,
|
||||||
|
"json_schema": schema,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def bench_schema(args):
|
||||||
|
arguments = convert_dataset(args.data_path)
|
||||||
|
|
||||||
|
if args.num_jsons < 0 or args.num_jsons > len(arguments):
|
||||||
|
args.num_jsons = len(arguments)
|
||||||
|
arguments = arguments[: args.num_jsons]
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = schema_gen.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Check if the outputs are valid
|
||||||
|
indices = []
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
try:
|
||||||
|
schema = json.loads(arguments[i]["json_schema"])
|
||||||
|
obj = json.loads(state["json_output"])
|
||||||
|
assert jsonschema.validate(obj, schema) is None
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
indices.append(i)
|
||||||
|
|
||||||
|
return states, latency
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
states, latency = bench_schema(args)
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
global_config.default_backend.get_server_info()["tokenizer_path"]
|
||||||
|
)
|
||||||
|
output_jsons = [state["json_output"] for state in states]
|
||||||
|
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
|
||||||
|
print(f"#output tokens: {num_output_tokens}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
with open(f"{args.backend}.jsonl", "w") as fout:
|
||||||
|
for state in states:
|
||||||
|
fout.write(state["json_output"] + "\n")
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "json_schema",
|
||||||
|
"backend": args.backend,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_jsons": args.num_jsons,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
|
||||||
|
parser.add_argument("--num-jsons", type=int, default=-1)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,403 @@
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
|
||||||
|
import cudnn
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||||
|
from sglang.srt.utils import should_use_tensor_core
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_forward(
|
||||||
|
fn,
|
||||||
|
*inputs,
|
||||||
|
repeats=10,
|
||||||
|
amp=False,
|
||||||
|
amp_dtype=torch.float16,
|
||||||
|
**kwinputs,
|
||||||
|
):
|
||||||
|
def amp_wrapper(*inputs, **kwinputs):
|
||||||
|
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||||
|
fn(*inputs, **kwinputs)
|
||||||
|
|
||||||
|
t = benchmark.Timer(
|
||||||
|
stmt="fn_amp(*inputs, **kwinputs)",
|
||||||
|
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
|
||||||
|
num_threads=torch.get_num_threads(),
|
||||||
|
)
|
||||||
|
m = t.timeit(repeats)
|
||||||
|
return t, m
|
||||||
|
|
||||||
|
|
||||||
|
def time_fwd(func, *args, **kwargs):
|
||||||
|
time_f = benchmark_forward(func, *args, **kwargs)
|
||||||
|
return time_f[1].mean * 1e6
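# time_fwd returns the mean forward latency in microseconds. Hypothetical usage:
#   us = time_fwd(torch.matmul, a, b, repeats=30)  # a, b are example tensors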
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_sglang(
|
||||||
|
q,
|
||||||
|
kv_data,
|
||||||
|
batch_size,
|
||||||
|
kv_len,
|
||||||
|
head_num_q,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
num_kv_splits,
|
||||||
|
warmup=10,
|
||||||
|
):
|
||||||
|
|
||||||
|
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
|
||||||
|
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
total_tokens = batch_size * kv_len
|
||||||
|
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
|
||||||
|
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
||||||
|
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
|
||||||
|
max_len_in_batch = kv_len
|
||||||
|
sm_scale = 1.0 / (head_dim**0.5)
|
||||||
|
|
||||||
|
attn_logits = torch.empty(
|
||||||
|
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(warmup):
|
||||||
|
decode_attention_fwd(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
num_kv_splits,
|
||||||
|
sm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
f = time_fwd(
|
||||||
|
decode_attention_fwd,
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
num_kv_splits,
|
||||||
|
sm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return f, o
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
|
||||||
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||||
|
use_tensor_cores = should_use_tensor_core(
|
||||||
|
kv_cache_dtype=dtype,
|
||||||
|
num_attention_heads=head_num_q,
|
||||||
|
num_kv_heads=head_num_kv,
|
||||||
|
)
|
||||||
|
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||||
|
)
|
||||||
|
|
||||||
|
class FlashinferAttention(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
kv_data,
|
||||||
|
batch_size,
|
||||||
|
kv_len,
|
||||||
|
head_num_q,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
dtype,
|
||||||
|
warmup=10,
|
||||||
|
):
|
||||||
|
total_tokens = batch_size * kv_len
|
||||||
|
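# Paged KV cache with page size 1: every token is its own page, so the
# indptr/indices below are plain aranges and each request's last page holds one token.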
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||||
|
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||||
|
kv_last_page_len = torch.full(
|
||||||
|
(batch_size,), 1, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
flashinfer_decode_wrapper.end_forward()
|
||||||
|
flashinfer_decode_wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_len,
|
||||||
|
head_num_q,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
1,
|
||||||
|
pos_encoding_mode="NONE",
|
||||||
|
data_type=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(warmup):
|
||||||
|
o = flashinfer_decode_wrapper.forward(
|
||||||
|
q.contiguous().view(-1, head_num_q, head_dim), kv_data
|
||||||
|
)
|
||||||
|
|
||||||
|
f = time_fwd(
|
||||||
|
flashinfer_decode_wrapper.forward,
|
||||||
|
q.contiguous().view(-1, head_num_q, head_dim),
|
||||||
|
kv_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
return f, o
|
||||||
|
|
||||||
|
return FlashinferAttention
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_cudnn_type(torch_type):
|
||||||
|
if torch_type == torch.float16:
|
||||||
|
return cudnn.data_type.HALF
|
||||||
|
elif torch_type == torch.bfloat16:
|
||||||
|
return cudnn.data_type.BFLOAT16
|
||||||
|
elif torch_type == torch.float32:
|
||||||
|
return cudnn.data_type.FLOAT
|
||||||
|
elif torch_type == torch.int32:
|
||||||
|
return cudnn.data_type.INT32
|
||||||
|
elif torch_type == torch.int64:
|
||||||
|
return cudnn.data_type.INT64
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported tensor data type.")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_cudnn(
|
||||||
|
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
|
||||||
|
):
|
||||||
|
# Prepare data: continuous q,k,v
|
||||||
|
dims_q = (batch_size, head_num_q, 1, head_dim)
|
||||||
|
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
|
||||||
|
q_gpu = q.as_strided(dims_q, strides_q)
|
||||||
|
o_gpu = (
|
||||||
|
torch.empty(batch_size * head_num_q * head_dim)
|
||||||
|
.half()
|
||||||
|
.cuda()
|
||||||
|
.as_strided(dims_q, strides_q)
|
||||||
|
)
|
||||||
|
|
||||||
|
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
|
||||||
|
strides_kv = (
|
||||||
|
kv_len * head_num_kv * head_dim,
|
||||||
|
head_dim,
|
||||||
|
head_num_kv * head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
|
||||||
|
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
|
||||||
|
|
||||||
|
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
|
||||||
|
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
|
||||||
|
attn_scale = 1.0 / (head_dim**0.5)
|
||||||
|
|
||||||
|
# Prepare data: paged k,v
|
||||||
|
block_size = 1
|
||||||
|
blocks_per_batch = math.ceil(kv_len / block_size)
|
||||||
|
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
|
||||||
|
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||||
|
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||||
|
page_table_k_gpu = (
|
||||||
|
torch.linspace(
|
||||||
|
0,
|
||||||
|
batch_size * blocks_per_batch - 1,
|
||||||
|
batch_size * blocks_per_batch,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
.reshape(blocks_per_batch, 1, batch_size, 1)
|
||||||
|
.transpose(0, 2)
|
||||||
|
)
|
||||||
|
page_table_v_gpu = page_table_k_gpu.clone()
|
||||||
|
|
||||||
|
graph = cudnn.pygraph(
|
||||||
|
io_data_type=convert_to_cudnn_type(dtype),
|
||||||
|
intermediate_data_type=cudnn.data_type.FLOAT,
|
||||||
|
compute_data_type=cudnn.data_type.FLOAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = graph.tensor_like(q_gpu)
|
||||||
|
container_k = graph.tensor_like(container_k_gpu)
|
||||||
|
container_v = graph.tensor_like(container_v_gpu)
|
||||||
|
page_table_k = graph.tensor_like(page_table_k_gpu)
|
||||||
|
page_table_v = graph.tensor_like(page_table_v_gpu)
|
||||||
|
|
||||||
|
seq_len_q = graph.tensor_like(seq_len_q_gpu)
|
||||||
|
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
|
||||||
|
|
||||||
|
o, _ = graph.sdpa(
|
||||||
|
name="sdpa",
|
||||||
|
q=q,
|
||||||
|
k=container_k,  # Container K: non-contiguous container with K blocks
|
||||||
|
v=container_v,  # Container V: non-contiguous container with V blocks
|
||||||
|
is_inference=True,
|
||||||
|
attn_scale=attn_scale,
|
||||||
|
use_causal_mask=False,
|
||||||
|
use_padding_mask=True,
|
||||||
|
seq_len_q=seq_len_q,
|
||||||
|
seq_len_kv=seq_len_kv,
|
||||||
|
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
|
||||||
|
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
|
||||||
|
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
|
||||||
|
)
|
||||||
|
|
||||||
|
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
|
||||||
|
|
||||||
|
graph.validate()
|
||||||
|
graph.build_operation_graph()
|
||||||
|
graph.create_execution_plans([cudnn.heur_mode.A])
|
||||||
|
graph.check_support()
|
||||||
|
graph.build_plans()
|
||||||
|
|
||||||
|
workspace = torch.empty(
|
||||||
|
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
|
||||||
|
)
|
||||||
|
|
||||||
|
variant_pack = {
|
||||||
|
q: q_gpu,
|
||||||
|
container_k: container_k_gpu,
|
||||||
|
container_v: container_v_gpu,
|
||||||
|
page_table_k: page_table_k_gpu,
|
||||||
|
page_table_v: page_table_v_gpu,
|
||||||
|
seq_len_q: seq_len_q_gpu,
|
||||||
|
seq_len_kv: seq_len_kv_gpu,
|
||||||
|
o: o_gpu,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in range(warmup):
|
||||||
|
graph.execute(variant_pack, workspace)
|
||||||
|
|
||||||
|
f = time_fwd(
|
||||||
|
graph.execute,
|
||||||
|
variant_pack,
|
||||||
|
workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
return f, o_gpu.squeeze(dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff():
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
batch_size = 64
|
||||||
|
kv_len = 4096
|
||||||
|
head_num_q = 64
|
||||||
|
head_num_kv = 8
|
||||||
|
head_dim = 128
|
||||||
|
|
||||||
|
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
|
||||||
|
kv_data = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||||
|
),
|
||||||
|
torch.randn(
|
||||||
|
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, output_sglang = decode_attention_sglang(
|
||||||
|
q,
|
||||||
|
kv_data,
|
||||||
|
batch_size,
|
||||||
|
kv_len,
|
||||||
|
head_num_q,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
num_kv_splits=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
|
||||||
|
_, output_flashinfer = attn_flashinfer(
|
||||||
|
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
_, output_cudnn = decode_attention_cudnn(
|
||||||
|
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"SGLang output={output_sglang}")
|
||||||
|
print(f"FlashInfer output={output_flashinfer}")
|
||||||
|
print(f"cuDNN output={output_cudnn}")
|
||||||
|
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
|
||||||
|
print("✅ SGLang[Triton] and FlashInfer match")
|
||||||
|
else:
|
||||||
|
print("❌ SGLang[Triton] and FlashInfer differ")
|
||||||
|
|
||||||
|
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
|
||||||
|
print("✅ SGLang[Triton] and cuDNN match")
|
||||||
|
else:
|
||||||
|
print("❌ SGLang[Triton] and cuDNN differ")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
calculate_diff()
|
||||||
|
|
||||||
|
head_dim = 128
|
||||||
|
dtype = torch.float16
|
||||||
|
batch_size_range = [2**i for i in range(0, 8, 2)]
|
||||||
|
kv_len_range = [2**i for i in range(6, 13, 1)]
|
||||||
|
configs = list(itertools.product(batch_size_range, kv_len_range))
|
||||||
|
|
||||||
|
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
|
||||||
|
attn_flashinfer = decode_attention_flashinfer(
|
||||||
|
dtype, head_num_q, head_num_kv
|
||||||
|
).apply
|
||||||
|
for batch_size, kv_len in configs:
|
||||||
|
q = torch.randn(
|
||||||
|
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
|
kv_data = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size * kv_len,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
torch.randn(
|
||||||
|
batch_size * kv_len,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
us_cudnn, output_cudnn = decode_attention_cudnn(
|
||||||
|
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||||
|
)
|
||||||
|
us_sglang, output_sglang = decode_attention_sglang(
|
||||||
|
q,
|
||||||
|
kv_data,
|
||||||
|
batch_size,
|
||||||
|
kv_len,
|
||||||
|
head_num_q,
|
||||||
|
head_num_kv,
|
||||||
|
head_dim,
|
||||||
|
num_kv_splits=8,
|
||||||
|
)
|
||||||
|
us_flashinfer, _ = attn_flashinfer(
|
||||||
|
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
head_num_q,
|
||||||
|
" ",
|
||||||
|
head_num_kv,
|
||||||
|
" ",
|
||||||
|
batch_size,
|
||||||
|
" ",
|
||||||
|
kv_len,
|
||||||
|
" ",
|
||||||
|
us_cudnn,
|
||||||
|
" ",
|
||||||
|
us_sglang,
|
||||||
|
" ",
|
||||||
|
us_flashinfer,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
## DeepSeek kernels benchmark
|
||||||
|
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- You should install [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) from source before running `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.
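If DeepGEMM is not installed yet, a from-source setup roughly follows the sketch below; the exact clone options and install command are assumptions, so defer to the DeepGEMM README for the authoritative steps.

```bash
# Assumed from-source install flow; consult the DeepGEMM README for exact steps.
git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git
cd DeepGEMM
pip install .   # or the repo's documented setup.py command
```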
|
||||||
|
|
||||||
|
### Benchmark
|
||||||
|
- `benchmark_deepgemm_fp8_gemm.py`
|
||||||
|
```bash
|
||||||
|
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
|
||||||
|
```
|
||||||
|
|
||||||
|
- `benchmark_deepgemm_fp8_group_gemm.py`
|
||||||
|
```bash
|
||||||
|
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
|
||||||
|
```
|
||||||
|
|
||||||
|
- You can use the `--run_correctness` parameter to verify the correctness of all kernel results.
|
||||||
|
- You can use the `--tp_size` parameter to benchmark all FP8 w8a8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the current tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) implementation with the SGLang and vLLM Triton implementations.
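For example, a run under a larger TP degree might look like the sketch below (TP=8 is only an illustration; use whatever TP setting you want to study):

```bash
# Benchmark the DeepSeek V3/R1 GEMM shapes as they would be sharded under TP=8
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 8
```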
|
||||||
|
|
@ -0,0 +1,398 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import deep_gemm
|
||||||
|
import tilelang
|
||||||
|
import tilelang.language as T
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
|
||||||
|
def tl_gemm(
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
in_dtype,
|
||||||
|
out_dtype,
|
||||||
|
accum_dtype,
|
||||||
|
):
|
||||||
|
assert in_dtype in [
|
||||||
|
"e4m3_float8",
|
||||||
|
], "Currently only e4m3_float8 is supported"
|
||||||
|
assert out_dtype in [
|
||||||
|
"bfloat16",
|
||||||
|
"float16",
|
||||||
|
], "Currently only bfloat16 and float16 are supported"
|
||||||
|
|
||||||
|
TILE_SIZE = (128, 128, 128)
|
||||||
|
block_M = TILE_SIZE[0]
|
||||||
|
block_N = TILE_SIZE[1]
|
||||||
|
block_K = TILE_SIZE[2]
|
||||||
|
|
||||||
|
A_shape = (M, K)
|
||||||
|
Scales_A_shape = (M, T.ceildiv(K, block_K))
|
||||||
|
B_shape = (N, K)
|
||||||
|
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
|
||||||
|
A_shared_shape = (block_M, block_K)
|
||||||
|
B_shared_shape = (block_N, block_K)
|
||||||
|
C_shared_shape = (block_M, block_N)
|
||||||
|
|
||||||
|
@T.prim_func
|
||||||
|
def main(
|
||||||
|
A: T.Buffer(A_shape, in_dtype),
|
||||||
|
scales_a: T.Buffer(Scales_A_shape, "float32"),
|
||||||
|
B: T.Buffer(B_shape, in_dtype),
|
||||||
|
scales_b: T.Buffer(Scales_B_shape, "float32"),
|
||||||
|
C: T.Buffer((M, N), out_dtype),
|
||||||
|
):
|
||||||
|
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
||||||
|
bx,
|
||||||
|
by,
|
||||||
|
):
|
||||||
|
|
||||||
|
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
|
||||||
|
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
|
||||||
|
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
|
||||||
|
Scale_C_shared = T.alloc_shared((block_M), "float32")
|
||||||
|
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||||
|
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||||
|
|
||||||
|
# Improve L2 Cache
|
||||||
|
T.use_swizzle(panel_size=10)
|
||||||
|
|
||||||
|
T.clear(C_local)
|
||||||
|
T.clear(C_local_accum)
|
||||||
|
K_iters = T.ceildiv(K, block_K)
|
||||||
|
for k in T.Pipelined(K_iters, num_stages=4):
|
||||||
|
# Load A into shared memory
|
||||||
|
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||||
|
# Load B into shared memory
|
||||||
|
T.copy(B[bx * block_N, k * block_K], B_shared)
|
||||||
|
# Load scale into shared memory
|
||||||
|
Scale_B = scales_b[bx, k]
|
||||||
|
for i in T.Parallel(block_M):
|
||||||
|
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
|
||||||
|
|
||||||
|
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
||||||
|
# Promote to enable 2xAcc
|
||||||
|
for i, j in T.Parallel(block_M, block_N):
|
||||||
|
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
||||||
|
T.clear(C_local)
|
||||||
|
# TMA store
|
||||||
|
T.copy(C_local_accum, C_shared)
|
||||||
|
T.copy(C_shared, C[by * block_M, bx * block_N])
|
||||||
|
|
||||||
|
return main
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||||
|
m, n = x.shape
|
||||||
|
x_view = x.view(m, -1, 128)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||||
|
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||||
|
m, n
|
||||||
|
), (x_amax / 448.0).view(m, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2
|
||||||
|
m, n = x.shape
|
||||||
|
x_padded = torch.zeros(
|
||||||
|
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
x_padded[:m, :n] = x
|
||||||
|
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||||
|
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||||
|
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||||
|
x_view.size(0), x_view.size(2)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_deepgemm(
|
||||||
|
x_fp8: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
y_fp8: torch.Tensor,
|
||||||
|
y_scale: torch.Tensor,
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
):
|
||||||
|
"""DeepGEMM implementation of FP8 GEMM"""
|
||||||
|
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Run DeepGEMM kernel
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_sglang(
|
||||||
|
x_fp8: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
y_fp8: torch.Tensor,
|
||||||
|
y_scale: torch.Tensor,
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
):
|
||||||
|
"""SGLang implementation of FP8 GEMM"""
|
||||||
|
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||||
|
|
||||||
|
# Run SGLang kernel
|
||||||
|
out = w8a8_block_fp8_matmul(
|
||||||
|
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_vllm(
|
||||||
|
x_fp8: torch.Tensor,
|
||||||
|
x_scale: torch.Tensor,
|
||||||
|
y_fp8: torch.Tensor,
|
||||||
|
y_scale: torch.Tensor,
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
):
|
||||||
|
"""vLLM implementation of FP8 GEMM"""
|
||||||
|
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||||
|
|
||||||
|
# Run vLLM kernel
|
||||||
|
out = vllm_w8a8_block_fp8_matmul(
|
||||||
|
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(m: int, n: int, k: int):
|
||||||
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
||||||
|
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
||||||
|
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||||
|
|
||||||
|
out_deepgemm = fp8_gemm_deepgemm(
|
||||||
|
x_fp8.clone(),
|
||||||
|
x_scale_col_major.clone(),
|
||||||
|
y_fp8.clone(),
|
||||||
|
y_scale.clone(),
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
)
|
||||||
|
out_sglang = fp8_gemm_sglang(
|
||||||
|
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
|
||||||
|
)
|
||||||
|
|
||||||
|
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||||
|
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||||
|
out_tilelang = tilelang_kernel(
|
||||||
|
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
|
||||||
|
)
|
||||||
|
|
||||||
|
diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
|
||||||
|
diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
|
||||||
|
diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
|
||||||
|
|
||||||
|
print(f"Shape m={m}, n={n}, k={k}:")
|
||||||
|
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||||
|
print(f"SGLang output: {out_sglang[0, 0:5]}")
|
||||||
|
print(f"TileLang output: {out_tilelang[0, 0:5]}")
|
||||||
|
print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
|
||||||
|
print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
|
||||||
|
print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
|
||||||
|
|
||||||
|
sglang_deepgemm_match = torch.allclose(
|
||||||
|
out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
|
||||||
|
)
|
||||||
|
tilelang_deepgemm_match = torch.allclose(
|
||||||
|
out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
|
||||||
|
)
|
||||||
|
tilelang_sglang_match = torch.allclose(
|
||||||
|
out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
|
||||||
|
)
|
||||||
|
|
||||||
|
if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
|
||||||
|
print("✅ All implementations match\n")
|
||||||
|
else:
|
||||||
|
print("❌ Some implementations differ:")
|
||||||
|
print(f" - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
|
||||||
|
print(f" - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
|
||||||
|
print(f" - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_shapes(tp_size):
|
||||||
|
# cannot TP
|
||||||
|
total = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
]
|
||||||
|
# N can TP
|
||||||
|
n_tp = [
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(24576, 1536),
|
||||||
|
(4096, 7168),
|
||||||
|
]
|
||||||
|
# K can TP
|
||||||
|
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||||
|
|
||||||
|
weight_shapes = []
|
||||||
|
for t in total:
|
||||||
|
weight_shapes.append(t)
|
||||||
|
for n_t in n_tp:
|
||||||
|
new_t = (n_t[0] // tp_size, n_t[1])
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
for k_t in k_tp:
|
||||||
|
new_t = (k_t[0], k_t[1] // tp_size)
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
|
||||||
|
return weight_shapes
|
||||||
|
|
||||||
|
|
||||||
|
def create_benchmark_configs(tp_size):
|
||||||
|
configs = []
|
||||||
|
weight_shapes = get_weight_shapes(tp_size)
|
||||||
|
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
|
||||||
|
|
||||||
|
for n, k in weight_shapes:
|
||||||
|
for m in batch_sizes:
|
||||||
|
configs.append((m, n, k, tp_size))
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark(tp_size):
|
||||||
|
all_configs = create_benchmark_configs(tp_size)
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["m", "n", "k", "tp_size"],
|
||||||
|
x_vals=[list(config) for config in all_configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["deepgemm", "sglang", "tilelang"],
|
||||||
|
line_names=["DeepGEMM", "SGLang", "TileLang"],
|
||||||
|
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(m, n, k, tp_size, provider):
|
||||||
|
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
|
||||||
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Preprocess data before benchmarking
|
||||||
|
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||||
|
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||||
|
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "deepgemm":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: fp8_gemm_deepgemm(
|
||||||
|
x_fp8.clone(),
|
||||||
|
x_scale_col_major.clone(),
|
||||||
|
y_fp8.clone(),
|
||||||
|
y_scale.clone(),
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "sglang":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: fp8_gemm_sglang(
|
||||||
|
x_fp8.clone(),
|
||||||
|
x_scale.clone(),
|
||||||
|
y_fp8.clone(),
|
||||||
|
y_scale.clone(),
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else: # tilelang
|
||||||
|
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||||
|
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: tilelang_kernel(
|
||||||
|
x_fp8.clone(),
|
||||||
|
x_scale.clone(),
|
||||||
|
y_fp8.clone(),
|
||||||
|
y_scale.clone(),
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate TFLOPS
|
||||||
|
flops = 2 * m * n * k # multiply-adds
|
||||||
|
tflops = flops / (ms * 1e-3) / 1e12
|
||||||
|
|
||||||
|
# Print shape-specific results with TFLOPS
|
||||||
|
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||||
|
return ms * 1000, max_ms * 1000, min_ms * 1000  # do_bench reports ms; scaled to us here
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/fp8_gemm/",
|
||||||
|
help="Path to save fp8 gemm benchmark results",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_correctness",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Whether to run correctness test",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallelism size to benchmark (default: 1)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# Run correctness tests on a few examples
|
||||||
|
if args.run_correctness:
|
||||||
|
print("Running correctness tests...")
|
||||||
|
calculate_diff(64, 512, 7168) # Small test
|
||||||
|
calculate_diff(64, 7168, 16384) # Medium test
|
||||||
|
calculate_diff(64, 18432, 7168) # Large test
|
||||||
|
|
||||||
|
# Get the benchmark function with the specified tp_size
|
||||||
|
benchmark = get_benchmark(args.tp_size)
|
||||||
|
|
||||||
|
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|
@ -0,0 +1,486 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import deep_gemm
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
||||||
|
|
||||||
|
# Import shared functionality from the regular GEMM benchmark
|
||||||
|
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||||
|
per_block_cast_to_fp8,
|
||||||
|
per_token_cast_to_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_grouped_and_flat_fp8(
|
||||||
|
x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
|
||||||
|
) -> Tuple[
|
||||||
|
Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8
|
||||||
|
Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8
|
||||||
|
Tuple[torch.Tensor, torch.Tensor], # flat x_fp8
|
||||||
|
Tuple[torch.Tensor, torch.Tensor], # flat y_fp8
|
||||||
|
torch.Tensor, # output
|
||||||
|
torch.Tensor, # reference output
|
||||||
|
]:
|
||||||
|
# Verify input shapes
|
||||||
|
m, k = x.shape
|
||||||
|
n, k_y = y.shape
|
||||||
|
assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
|
||||||
|
assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
|
||||||
|
assert m % 4 == 0, f"TMA alignment error: {m}"
|
||||||
|
|
||||||
|
# Reshape inputs for grouped processing
|
||||||
|
m_per_group = m // num_groups
|
||||||
|
x_grouped = x.view(num_groups, m_per_group, k)
|
||||||
|
y_grouped = y.unsqueeze(0).expand(num_groups, n, k)
|
||||||
|
|
||||||
|
# Initialize output tensors
|
||||||
|
out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
|
||||||
|
ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)
|
||||||
|
|
||||||
|
# Quantize grouped tensors
|
||||||
|
x_fp8_grouped = (
|
||||||
|
torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
|
||||||
|
torch.empty(
|
||||||
|
(num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
|
||||||
|
),
|
||||||
|
)
|
||||||
|
y_fp8_grouped = (
|
||||||
|
torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
|
||||||
|
torch.empty(
|
||||||
|
(num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for i in range(num_groups):
|
||||||
|
x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
|
||||||
|
y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])
|
||||||
|
|
||||||
|
# Quantize flat tensors
|
||||||
|
x_fp8_flat = per_token_cast_to_fp8(x)
|
||||||
|
y_fp8_flat = per_block_cast_to_fp8(y)
|
||||||
|
|
||||||
|
# For non-masked input, merge the group and M dims in output
|
||||||
|
if not is_masked:
|
||||||
|
x_fp8_grouped = (
|
||||||
|
x_fp8_grouped[0].view(-1, k),
|
||||||
|
per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
|
||||||
|
)
|
||||||
|
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||||
|
|
||||||
|
# Transpose earlier for testing
|
||||||
|
x_fp8_grouped = (
|
||||||
|
x_fp8_grouped[0],
|
||||||
|
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||||
|
)
|
||||||
|
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||||
|
|
||||||
|
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||||
|
|
||||||
|
|
||||||
|
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
|
||||||
|
# custom kernel based on the Triton tutorial.
|
||||||
|
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
||||||
|
@triton.jit
|
||||||
|
def fp8_gemm_group_triton_kernel(
|
||||||
|
# Pointers to matrices
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr,
|
||||||
|
# Pointers to scaling factors
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
# Matrix dimensions
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||||
|
# element in a particular dimension.
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
# Strides for scaling factors
|
||||||
|
stride_a_scale_m,
|
||||||
|
stride_a_scale_k,
|
||||||
|
stride_b_scale_n,
|
||||||
|
stride_b_scale_k,
|
||||||
|
# Meta-parameters
|
||||||
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
|
GROUP_SIZE_M: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
|
||||||
|
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||||
|
|
||||||
|
Note: Block sizes must be multiples of 32 for optimal TMA performance.
|
||||||
|
"""
|
||||||
|
# Map program ids to the block of C it should compute
|
||||||
|
pid_group = tl.program_id(axis=0) # Group ID
|
||||||
|
pid_n = tl.program_id(axis=1) # N dimension ID
|
||||||
|
|
||||||
|
# Compute the M block ID within this group
|
||||||
|
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
|
||||||
|
pid_m_within_group = tl.program_id(axis=2) % group_size_m
|
||||||
|
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
|
||||||
|
|
||||||
|
# Create pointers for the first blocks of A and B
|
||||||
|
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||||
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||||
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||||
|
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||||
|
|
||||||
|
# Initialize accumulator
|
||||||
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
|
||||||
|
# Main loop
|
||||||
|
for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
|
k_offset = k_block * BLOCK_SIZE_K
|
||||||
|
|
||||||
|
# Load the next block of A and B, with masks
|
||||||
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)
|
||||||
|
|
||||||
|
# Calculate indices for scaling factors for this K block
|
||||||
|
a_scale_ptrs = a_scale_ptr + (
|
||||||
|
offs_am * stride_a_scale_m + k_block * stride_a_scale_k
|
||||||
|
)
|
||||||
|
b_scale_ptrs = b_scale_ptr + (
|
||||||
|
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform matrix multiplication in FP8
|
||||||
|
res = tl.dot(a, b)
|
||||||
|
|
||||||
|
# Load scaling factors for the current block
|
||||||
|
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
|
||||||
|
b_scale = tl.load(b_scale_ptrs)
|
||||||
|
|
||||||
|
# Apply scaling factors to the accumulated result
|
||||||
|
accumulator += res * a_scale * b_scale
|
||||||
|
|
||||||
|
# Advance pointers
|
||||||
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
# Convert to bfloat16 for output
|
||||||
|
c = accumulator.to(tl.bfloat16)
|
||||||
|
|
||||||
|
# Write back the result
|
||||||
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||||
|
"""
|
||||||
|
Perform matrix multiplication with FP8 inputs and proper scaling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
|
||||||
|
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
|
||||||
|
c: Output tensor in BF16 format
|
||||||
|
num_groups: Number of groups for grouped GEMM
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result tensor in BF16 format
|
||||||
|
"""
|
||||||
|
# Unpack the tuples
|
||||||
|
a, a_scale = a_tuple
|
||||||
|
b, b_scale = b_tuple
|
||||||
|
|
||||||
|
M, K = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
|
||||||
|
# Configure block sizes - must be multiples of 32 for TMA alignment
|
||||||
|
BLOCK_SIZE_M = 128
|
||||||
|
BLOCK_SIZE_N = 128
|
||||||
|
BLOCK_SIZE_K = 128
|
||||||
|
|
||||||
|
# Calculate grid dimensions
|
||||||
|
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
|
||||||
|
|
||||||
|
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
|
||||||
|
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
|
||||||
|
|
||||||
|
fp8_gemm_group_triton_kernel[grid](
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c,
|
||||||
|
a_scale,
|
||||||
|
b_scale,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
a.stride(0),
|
||||||
|
a.stride(1),
|
||||||
|
b.stride(0),
|
||||||
|
b.stride(1),
|
||||||
|
c.stride(0),
|
||||||
|
c.stride(1),
|
||||||
|
a_scale.stride(0),
|
||||||
|
1,  # a_scale is contiguous, so its stride along the K-block dimension is 1
|
||||||
|
b_scale.stride(0),
|
||||||
|
1 if b_scale.dim() > 1 else 0,
|
||||||
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
|
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||||
|
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
|
GROUP_SIZE_M=num_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||||
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||||
|
x_fp8_grouped,
|
||||||
|
y_fp8_grouped,
|
||||||
|
out,
|
||||||
|
m_indices,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(m: int, n: int, k: int, num_groups: int):
|
||||||
|
print(f"Shape (m={m}, n={n}, k={k}")
|
||||||
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||||
|
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||||
|
)
|
||||||
|
m_per_group = m // num_groups
|
||||||
|
out_deepgemm = out.clone()
|
||||||
|
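# m_indices labels each row of the flattened activation with its group id;
# rows of a group stay contiguous, which is what the contiguous grouped GEMM consumes.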
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||||
|
m_indices = (
|
||||||
|
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
fp8_gemm_group_deepgemm(
|
||||||
|
x_fp8_grouped,
|
||||||
|
y_fp8_grouped,
|
||||||
|
out_deepgemm,
|
||||||
|
m_indices,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Prepare inputs for Triton
|
||||||
|
a, a_scale = x_fp8_flat
|
||||||
|
b, b_scale = y_fp8_flat
|
||||||
|
b = b.T.contiguous()
|
||||||
|
# Ensure scales are in the right format and contiguous
|
||||||
|
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||||
|
M, _ = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||||
|
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
|
||||||
|
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
|
||||||
|
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
|
||||||
|
|
||||||
|
print(f"Shape m={m}, n={n}, k={k}:")
|
||||||
|
print(f"Torch output: {out_torch[0, 0:5]}")
|
||||||
|
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||||
|
print(f"Triton output: {out_triton[0, 0:5]}")
|
||||||
|
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
|
||||||
|
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
|
||||||
|
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
|
||||||
|
|
||||||
|
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
|
||||||
|
triton_torch_diff = calc_diff(out_triton, out_torch)
|
||||||
|
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
|
||||||
|
|
||||||
|
DIFF_THRESHOLD = 0.001
|
||||||
|
all_match = (
|
||||||
|
deepgemm_torch_diff < DIFF_THRESHOLD
|
||||||
|
and triton_torch_diff < DIFF_THRESHOLD
|
||||||
|
and deepgemm_triton_diff < DIFF_THRESHOLD
|
||||||
|
)
|
||||||
|
if all_match:
|
||||||
|
print("✅ All implementations match\n")
|
||||||
|
else:
|
||||||
|
print("❌ Some implementations differ:")
|
||||||
|
print(
|
||||||
|
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_shapes(tp_size):
|
||||||
|
# cannot TP
|
||||||
|
total = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
]
|
||||||
|
# N can TP
|
||||||
|
n_tp = [
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(24576, 1536),
|
||||||
|
(4096, 7168),
|
||||||
|
]
|
||||||
|
# K can TP
|
||||||
|
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||||
|
|
||||||
|
weight_shapes = []
|
||||||
|
for t in total:
|
||||||
|
weight_shapes.append(t)
|
||||||
|
for n_t in n_tp:
|
||||||
|
new_t = (n_t[0] // tp_size, n_t[1])
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
for k_t in k_tp:
|
||||||
|
new_t = (k_t[0], k_t[1] // tp_size)
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
|
||||||
|
return weight_shapes
|
||||||
|
|
||||||
|
|
||||||
|
def create_benchmark_configs(tp_size):
|
||||||
|
configs = []
|
||||||
|
weight_shapes = get_weight_shapes(tp_size)
|
||||||
|
batch_sizes = [2048, 4096]
|
||||||
|
group_sizes = [4, 8]
|
||||||
|
for n, k in weight_shapes:
|
||||||
|
for m in batch_sizes:
|
||||||
|
for num_groups in group_sizes:
|
||||||
|
configs.append((m, n, k, num_groups, tp_size))
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark(tp_size):
|
||||||
|
all_configs = create_benchmark_configs(tp_size)
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["m", "n", "k", "num_groups", "tp_size"],
|
||||||
|
x_vals=[config for config in all_configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["deepgemm", "triton"],
|
||||||
|
line_names=["DeepGEMM", "Triton"],
|
||||||
|
styles=[("blue", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(m, n, k, num_groups, tp_size, provider):
|
||||||
|
print(
|
||||||
|
f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
|
||||||
|
)
|
||||||
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||||
|
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||||
|
)
|
||||||
|
m_per_group = m // num_groups
|
||||||
|
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||||
|
m_indices = (
|
||||||
|
m_indices.unsqueeze(-1)
|
||||||
|
.expand(num_groups, m_per_group)
|
||||||
|
.contiguous()
|
||||||
|
.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "deepgemm":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: fp8_gemm_group_deepgemm(
|
||||||
|
x_fp8_grouped,
|
||||||
|
y_fp8_grouped,
|
||||||
|
out,
|
||||||
|
m_indices,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "triton":
|
||||||
|
# Prepare inputs for Triton
|
||||||
|
# Prepared outside the lambda so the timed region matches DeepGEMM's (data preparation is excluded from both).
|
||||||
|
a, a_scale = x_fp8_flat
|
||||||
|
b, b_scale = y_fp8_flat
|
||||||
|
b = b.T.contiguous()
|
||||||
|
# Ensure scales are in the right format and contiguous
|
||||||
|
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||||
|
M, _ = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: fp8_gemm_group_triton(
|
||||||
|
(a, a_scale),
|
||||||
|
(b, b_scale),
|
||||||
|
c,
|
||||||
|
num_groups,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate TFLOPS
|
||||||
|
flops = 2 * m * n * k # multiply-adds
|
||||||
|
tflops = flops / (ms * 1e-3) / 1e12
|
||||||
|
|
||||||
|
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||||
|
return ms * 1000, max_ms * 1000, min_ms * 1000  # do_bench reports ms; scaled to us here
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/fp8_group_gemm/",
|
||||||
|
help="Path to save deepgemm fp8 group gemm benchmark results",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_correctness",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to run correctness test",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallelism size to benchmark (default: 1)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# Run correctness tests on a few examples
|
||||||
|
if args.run_correctness:
|
||||||
|
print("Running correctness tests...")
|
||||||
|
calculate_diff(8192, 7168, 4096, 4)
|
||||||
|
calculate_diff(8192, 2048, 7168, 4)
|
||||||
|
calculate_diff(4096, 7168, 4096, 8)
|
||||||
|
calculate_diff(4096, 2048, 7168, 8)
|
||||||
|
calculate_diff(4096, 576, 7168, 8)
|
||||||
|
|
||||||
|
# Get the benchmark function with the specified tp_size
|
||||||
|
benchmark = get_benchmark(args.tp_size)
|
||||||
|
|
||||||
|
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
## Benchmark Kernels
|
||||||
|
|
||||||
|
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
||||||
|
|
||||||
|
### Tuning Tool
|
||||||
|
|
||||||
|
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```bash
|
||||||
|
# Tune Qwen2-57B with FP8 and TP=4
|
||||||
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
|
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||||
|
--tp-size 4 \
|
||||||
|
--dtype fp8_w8a8 \
|
||||||
|
--tune
|
||||||
|
|
||||||
|
# Tune Mixtral-8x7B with default settings
|
||||||
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
|
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||||
|
--tune
|
||||||
|
```
|
||||||
|
|
||||||
|
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
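A minimal sketch of that last step, using the example file name from above (your generated name will differ by device and model shapes):

```bash
mv "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json" \
   sglang/srt/layers/fused_moe_triton/configs/
```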
|
||||||
|
|
||||||
|
### Performance Comparison Tool
|
||||||
|
|
||||||
|
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between the vLLM and SGLang implementations. Supports various model architectures and data types.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```bash
|
||||||
|
# Compare with default settings (Mixtral model)
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
||||||
|
|
||||||
|
# Compare with FP8 mode for Qwen2-57B
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||||
|
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||||
|
--use-fp8
|
||||||
|
|
||||||
|
# Compare with custom TP size
|
||||||
|
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||||
|
--tp-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||||
|
|
||||||
|
- `benchmark_torch_compile_fused_moe.py`: A tool for comparing the performance of the fused MoE kernel compiled with `torch.compile` against the original fused MoE kernel.
|
||||||
|
|
||||||
|
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
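For instance, a run mirroring the FP8 example above might look like this sketch (assuming the script sits alongside the other scripts in this directory and, as stated, accepts the same flags):

```bash
python benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py \
    --model Qwen/Qwen2-57B-A14B-Instruct \
    --use-fp8
```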
|
||||||
|
|
@ -0,0 +1,294 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe as fused_moe_triton,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
|
"""Get model configuration parameters"""
|
||||||
|
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
if config.architectures[0] == "DbrxForCausalLM":
|
||||||
|
E = config.ffn_config.moe_num_experts
|
||||||
|
topk = config.ffn_config.moe_top_k
|
||||||
|
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
|
E = config.n_routed_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] in [
|
||||||
|
"Grok1ForCausalLM",
|
||||||
|
"Grok1ImgGen",
|
||||||
|
"Grok1AForCausalLM",
|
||||||
|
]:
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
else:
|
||||||
|
# Default: Mixtral
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
|
||||||
|
shape_configs = {
|
||||||
|
"num_experts": E,
|
||||||
|
"topk": topk,
|
||||||
|
"hidden_size": config.hidden_size,
|
||||||
|
"shard_intermediate_size": shard_intermediate_size,
|
||||||
|
"dtype": config.torch_dtype,
|
||||||
|
}
|
||||||
|
print(f"{shape_configs=}")
|
||||||
|
return shape_configs
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk_native(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
topk_weights = torch.empty(
|
||||||
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
||||||
|
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=False)
|
||||||
|
def fused_moe_torch(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert not use_fp8_w8a8, "Not supported"
|
||||||
|
|
||||||
|
topk_weights, topk_ids = fused_topk_native(
|
||||||
|
hidden_states=x,
|
||||||
|
gating_output=input_gating,
|
||||||
|
topk=topk,
|
||||||
|
renormalize=True,
|
||||||
|
)
|
||||||
|
w13_weights = w1[topk_ids]
|
||||||
|
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||||
|
w2_weights = w2[topk_ids]
|
||||||
|
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||||
|
x1 = F.silu(x1)
|
||||||
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||||
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||||
|
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_torch_compile(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=False,
|
||||||
|
w1_scale=None,
|
||||||
|
w2_scale=None,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
):
|
||||||
|
return fused_moe_torch(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
input_gating,
|
||||||
|
topk,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)


def fused_moe_sglang_api(
    x, w1, w2, input_gating, topk, use_fp8_w8a8=False,
    w1_scale=None, w2_scale=None, a1_scale=None, a2_scale=None,
):
    return fused_moe_triton(
        x, w1, w2, input_gating, topk, renormalize=True, inplace=True,
        use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale,
        a1_scale=a1_scale, a2_scale=a2_scale,
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 5)),
        line_arg="provider",
        line_vals=["fused_moe_triton", "fused_moe_torch_compile"],
        line_names=["fused_moe_triton", "fused_moe_torch_compile"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    set_torch_compile_config()

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    if use_fp8:
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )
        w1_scale = w2_scale = a1_scale = a2_scale = None

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    # Warmup
    api_func = (
        fused_moe_torch_compile
        if provider == "fused_moe_torch_compile"
        else fused_moe_sglang_api
    )
    for _ in range(10):
        y = api_func(
            x, w1, w2, input_gating, topk, use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale,
        )
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x, w1, w2, input_gating, topk, use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/fused_moe_torch_compile/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,299 @@
import argparse

import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe as fused_moe_sglang,
)


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ["Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"]:
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    vllm_version_num = (
        vllm.__version_tuple__[0] * 100
        + vllm.__version_tuple__[1] * 10
        + vllm.__version_tuple__[2]
    )
    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2
        assert (
            vllm_version_num >= 66
        ), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
    }
    print(f"{shape_configs=}")
    return shape_configs


def fused_moe_vllm_api(
    x, w1, w2, input_gating, topk, use_fp8_w8a8=False, w1_scale=None,
    w2_scale=None, a1_scale=None, a2_scale=None, block_shape=None,
):
    if block_shape is not None:
        return fused_moe_vllm(
            x, w1, w2, input_gating, topk, renormalize=True, inplace=True,
            use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale,
            a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape,
        )
    else:
        return fused_moe_vllm(
            x, w1, w2, input_gating, topk, renormalize=True, inplace=True,
            use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale,
            a1_scale=a1_scale, a2_scale=a2_scale,
        )


def fused_moe_sglang_api(
    x, w1, w2, input_gating, topk, use_fp8_w8a8=False, w1_scale=None,
    w2_scale=None, a1_scale=None, a2_scale=None, block_shape=None,
):
    return fused_moe_sglang(
        x, w1, w2, input_gating, topk, renormalize=True, inplace=True,
        use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale,
        a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape,
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 513)),
        line_arg="provider",
        line_vals=["vllm_fused_moe_triton", "sglang_fused_moe_triton"],
        line_names=["vllm_fused_moe_triton", "sglang_fused_moe_triton"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]
    block_shape = getattr(model_config, "block_shape", None)

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1_scale = w2_scale = a1_scale = a2_scale = None

    if use_fp8:
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

        if block_shape is None:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)
        else:
            block_n, block_k = block_shape[0], block_shape[1]
            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
            w1_scale = torch.rand(
                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
            )
            w2_scale = torch.rand(
                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
            )
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    # Warmup
    api_func = (
        fused_moe_vllm_api
        if provider == "vllm_fused_moe_triton"
        else fused_moe_sglang_api
    )
    for _ in range(10):
        y = api_func(
            x, w1, w2, input_gating, topk, use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale,
            a2_scale=a2_scale, block_shape=block_shape,
        )
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x, w1, w2, input_gating, topk, use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale,
            a2_scale=a2_scale, block_shape=block_shape,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,564 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe,
    get_config_dtype_str,
    get_config_file_name,
    get_default_config,
    get_moe_configs,
)
from sglang.srt.utils import is_hip

_is_hip_ = is_hip()


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    block_shape: List[int] = None,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16 or use_int8_w8a8:
        w1 = torch.randint(
            -127, 127,
            (num_experts, shard_intermediate_size, hidden_size),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127, 127,
            (num_experts, hidden_size, shard_intermediate_size // 2),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8 or use_int8_w8a8:
        if use_int8_w8a8 and block_shape is None:
            w1_scale = torch.randn(
                num_experts, shard_intermediate_size, dtype=torch.float32
            )
            w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
        elif block_shape is None:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)
            a1_scale = torch.randn(1, dtype=torch.float32)
            a2_scale = torch.randn(1, dtype=torch.float32)
        else:
            block_n, block_k = block_shape[0], block_shape[1]
            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
            w1_scale = torch.rand(
                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
            )
            w2_scale = torch.rand(
                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
            )

    if use_fp8_w8a8:
        w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from sglang.srt.layers.moe.fused_moe_triton import override_config

        with override_config(config):
            fused_moe(
                x, w1, w2, input_gating, topk, renormalize=True, inplace=True,
                use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8,
                use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale,
                a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: List[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
    configs: List[BenchmarkConfig] = []
    waves_per_eu_range = 0
    for num_stages in [2]:
        for block_m in [32, 64, 128, 256]:
            for block_k in [32, 64, 128, 256]:
                for block_n in [16, 32, 64, 128, 256]:
                    for num_warps in [1, 2, 4, 8]:
                        for group_size in [1, 4, 8, 16, 32]:
                            configs.append(
                                {
                                    "BLOCK_SIZE_M": block_m,
                                    "BLOCK_SIZE_N": block_n,
                                    "BLOCK_SIZE_K": block_k,
                                    "GROUP_SIZE_M": group_size,
                                    "num_warps": num_warps,
                                    "num_stages": num_stages,
                                    "waves_per_eu": waves_per_eu_range,
                                }
                            )
    return configs


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    configs: List[BenchmarkConfig] = []
    if _is_hip_:
        configs = get_rocm_configs_compute_bound()
    else:
        for num_stages in [2, 3, 4, 5]:
            for block_m in [16, 32, 64, 128, 256]:
                for block_k in [64, 128, 256]:
                    for block_n in [32, 64, 128, 256]:
                        for num_warps in [4, 8]:
                            for group_size in [1, 16, 32, 64]:
                                configs.append(
                                    {
                                        "BLOCK_SIZE_M": block_m,
                                        "BLOCK_SIZE_N": block_n,
                                        "BLOCK_SIZE_K": block_k,
                                        "GROUP_SIZE_M": group_size,
                                        "num_warps": num_warps,
                                        "num_stages": num_stages,
                                    }
                                )
    return configs


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(0)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a8: bool,
        use_int8_w8a16: bool,
        block_shape: List[int],
    ) -> Tuple[Dict[str, int], float]:
        torch.cuda.manual_seed_all(0)
        dtype_str = get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens, num_experts, shard_intermediate_size, hidden_size,
                topk, dtype_str, False, block_shape,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config, num_tokens, num_experts, shard_intermediate_size, hidden_size,
            topk, dtype, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, block_shape,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a8: bool,
        use_int8_w8a16: bool,
        block_shape: List[int],
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(
                    config, num_tokens, num_experts, shard_intermediate_size,
                    hidden_size, topk, dtype, use_fp8_w8a8, use_int8_w8a8,
                    use_int8_w8a16, block_shape, num_iters=10,
                )
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **(
            {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
        ),
    }


def save_configs(
    configs: Dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    block_shape: List[int],
) -> None:
    dtype_str = get_config_dtype_str(
        dtype,
        use_int8_w8a16=use_int8_w8a16,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_shape
    )

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] in ["Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"]:
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a8 = args.dtype == "int8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512,
            1024, 1536, 2048, 3072, 4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        if block_shape is not None:
            block_n, block_k = block_shape[0], block_shape[1]
            search_space = [
                config
                for config in search_space
                if block_k % config["BLOCK_SIZE_K"] == 0
            ]
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
                    use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, block_shape,
                    search_space,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs, E, shard_intermediate_size, hidden_size, topk, dtype,
            use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, block_shape,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
                    use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, block_shape,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
        default="auto",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
@@ -0,0 +1,576 @@
import itertools
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)


def lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation"""
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Get padded dimensions (power of 2)
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

    # Create padded tensors without actually padding the data
    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
    kv_padded = torch.empty(
        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
    )

    # Copy data to padded tensors
    q_padded[..., :d] = q
    k_padded[..., :d] = k
    v_padded[..., :e] = v
    kv_padded[..., :d, :e] = kv

    # Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded, k_padded, v_padded, kv_padded, o_padded, s,
        b=b, h=h, n=n, d=d_padded, d_original=d, e=e_padded, e_original=e,
    )

    # Get unpadded outputs
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

    return o, kv_out


def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))


class MiniMaxText01LightningAttention(nn.Module):
    def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
        super().__init__()
        if config is None:
            config = type("Config", (), kwargs)

        bias = False
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        self.out_proj = nn.Linear(
            self.head_dim * self.num_heads, self.hidden_size, bias=bias
        )
        self.act = get_activation_fn(config.hidden_act)
        self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)

        self.qkv_proj = nn.Linear(
            self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
        )
        self.output_gate = nn.Linear(
            self.hidden_size, self.head_dim * self.num_heads, bias=bias
        )

        # for inference only
        self.offset = 0
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if (not self.training) and (not do_eval):
            return self.inference(
                hidden_states,
                attn_mask,
                output_attentions,
                past_key_value,
                use_cache,
                slope_rate,
            )

    def inference(
        self,
        x,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
    ):
        # x: b n d
        b, n, d = x.shape
        # linear map
        qkv = self.act(self.qkv_proj(x))
        new_shape = qkv.size()[:-1] + (self.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
        q = q.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, d]
        k = k.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, d]
        v = v.transpose(1, 2)  # [b, n, h, d] -> [b, h, n, e]

        self.offset += 1
        ratio = torch.exp(-slope_rate)  # [h, 1, 1]

        # decode mode
        kv = past_key_value  # [b, h, d, e]
        output = []
        for i in range(n):
            # kv: [b, h, d, e]; ratio: [h, 1, 1]
            # k[:, :, i : i + 1]: [b, h, 1, d]; v[:, :, i : i + 1]: [b, h, 1, e]
            # [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
            kv = ratio * kv + torch.einsum(
                "... n d, ... n e -> ... d e",
                k[:, :, i : i + 1],
                v[:, :, i : i + 1],
            )
            # q[:, :, i : i + 1]: [b, h, 1, d]; kv.to(q.dtype): [b, h, d, e]
            # [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
            qkv = torch.einsum(
                "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
            )
            output.append(qkv)
        output = torch.cat(output, dim=-2)

        # reshape
        output = rearrange(output, "b h n d -> b n (h d)")
        # normalize
        output = self.norm(output)
        # gate
        output = F.sigmoid(self.output_gate(x)) * output
        # outproj
        output = self.out_proj(output)

        attn_weights = None

        return output, attn_weights, kv


def get_activation_fn(activation):
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":

        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdims=True).values
                y = torch.exp(x - x_max)
            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        return lambda x: x


class MiniMaxText01RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniMaxText01RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def test_lightning_attention_implementations(model_params):
    torch.manual_seed(42)

    batch_size = 64
    seq_len = 1
    dtype = torch.bfloat16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hidden_states = torch.randn(
        batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
    )

    attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)

    slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)

    model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
    model_attn.eval()

    d = model_params["head_dim"]
    past_kv = torch.randn(
        batch_size, model_params["num_attention_heads"], d, d, device=device
    )
    with torch.no_grad():
        model_output, _, new_kv = model_attn.inference(
            hidden_states,
            attn_mask=attention_mask,
            slope_rate=slope_rate,
            past_key_value=past_kv,
        )

    qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
    new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
    qkv = qkv.view(*new_shape)
    q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    past_kv = past_kv.contiguous()
    slope_rate = slope_rate.contiguous()

    # Test Triton implementation
    triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
    triton_output = triton_output.transpose(1, 2).contiguous()
    triton_output = triton_output.view(batch_size, seq_len, -1)
    triton_output = model_attn.norm(triton_output)
    triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
    triton_output = model_attn.out_proj(triton_output)

    # Test SGL implementation
    sgl_output = torch.empty_like(v)
    sgl_new_kv = torch.empty_like(past_kv)
    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)

    sgl_output = sgl_output.transpose(1, 2).contiguous()
    sgl_output = sgl_output.view(batch_size, seq_len, -1)
    sgl_output = model_attn.norm(sgl_output)
    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
    sgl_output = model_attn.out_proj(sgl_output)

    # Verify Triton implementation results
    torch.testing.assert_close(
        model_output,
        triton_output,
        rtol=1e-3,
        atol=1e-2,
        msg="Triton lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        triton_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="Triton lightning attention implementation produces different kv results",
    )

    # Verify SGL implementation results
    torch.testing.assert_close(
        model_output,
        sgl_output,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        sgl_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different kv results",
    )

    print("✅ All implementations match")


def _build_slope_tensor(n_attention_heads: int):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
        n_attention_heads, 1, 1
    )
    return slopes


def get_benchmark():
    batch_size_range = [i for i in range(1, 33)]  # max 32
    seq_length_range = [1]  # decode mode sequence length is fixed to 1
    configs = list(itertools.product(batch_size_range, seq_length_range))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["Original", "Triton", "SGL"],
            line_names=[
                "Original PyTorch Implementation",
                "Triton Implementation",
                "SGL Implementation",
            ],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="lightning-attention-decode-performance",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, provider):
        dtype = torch.bfloat16
        device = torch.device("cuda")

        params = {
            "hidden_size": 6144,
            "num_attention_heads": 64,
            "head_dim": 96,
            "hidden_act": "gelu",
        }

        hidden_states = torch.randn(
            batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
        )

        attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)

        slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
        model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
        model_attn.eval()

        d = params["head_dim"]
        past_kv = torch.randn(
            batch_size, params["num_attention_heads"], d, d, device=device
        )

        quantiles = [0.5, 0.2, 0.8]
        if provider == "Original":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: model_attn.inference(
                    hidden_states,
                    attn_mask=attention_mask,
                    slope_rate=slope_rate,
                    past_key_value=past_kv,
                ),
                quantiles=quantiles,
            )
        elif provider == "Triton":

            def run_triton():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2)
                k = k.transpose(1, 2)
                v = v.transpose(1, 2)

                output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
                output = output.transpose(1, 2).contiguous()
                output = output.view(batch_size, seq_len, -1)
                output = model_attn.norm(output)
                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
                return model_attn.out_proj(output)

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_triton,
                quantiles=quantiles,
            )
        else:  # SGL

            def run_sgl():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2).contiguous()
                k = k.transpose(1, 2).contiguous()
                v = v.transpose(1, 2).contiguous()

                output = torch.empty_like(v)
                new_kv = torch.empty_like(past_kv)
                sgl_lightning_attention_decode(
                    q, k, v, past_kv, slope_rate, output, new_kv
                )

                output = output.transpose(1, 2).contiguous()
                output = output.view(batch_size, seq_len, -1)
                output = model_attn.norm(output)
                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
                return model_attn.out_proj(output)

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_sgl,
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    params = {
        "hidden_size": 6144,
        "num_attention_heads": 64,
        "head_dim": 96,
        "hidden_act": "silu",
    }
    # Run correctness test first
    # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
    test_lightning_attention_implementations(params)

    # Run performance benchmark
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=args.save_path)
@@ -0,0 +1,603 @@
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
|
||||||
|
@triton.jit
|
||||||
|
def _fwd_kernel(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
Out,
|
||||||
|
S, # log lambda
|
||||||
|
b: tl.constexpr,
|
||||||
|
h: tl.constexpr,
|
||||||
|
n: tl.constexpr,
|
||||||
|
d: tl.constexpr,
|
||||||
|
e: tl.constexpr,
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
NUM_BLOCK: tl.constexpr,
|
||||||
|
BLOCK_MODEL: tl.constexpr,
|
||||||
|
):
|
||||||
|
##### get offset
|
||||||
|
off_bh = tl.program_id(0)
|
||||||
|
off_h = off_bh % h
|
||||||
|
off_e = tl.program_id(1)
|
||||||
|
qk_offset = off_bh * n * d
|
||||||
|
v_offset = off_bh * n * e
|
||||||
|
o_offset = off_bh * n * e
|
||||||
|
# channel offset
|
||||||
|
e_offset = off_e * BLOCK_MODEL
|
||||||
|
|
||||||
|
##### get block ptr
|
||||||
|
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
|
||||||
|
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
|
||||||
|
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
|
||||||
|
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
|
||||||
|
S_block_ptr = S + off_h
|
||||||
|
|
||||||
|
##### init diag decay(Lambda); q, k decay; kv
|
||||||
|
s = tl.load(S_block_ptr)
|
||||||
|
# q, k decay
|
||||||
|
off_block = tl.arange(
|
||||||
|
0, BLOCK
|
||||||
|
) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
|
||||||
|
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
|
||||||
|
k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
|
||||||
|
block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
|
||||||
|
# diag decay
|
||||||
|
index = off_block[:, None] - off_block[None, :]
|
||||||
|
s_index = s * index
|
||||||
|
s_index = tl.where(index >= 0, -s_index, float("-inf"))
|
||||||
|
diag_decay = tl.exp(s_index)
|
||||||
|
kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
|
||||||
|
|
||||||
|
##### compute
|
||||||
|
for i in range(NUM_BLOCK):
|
||||||
|
# load
|
||||||
|
q = tl.load(
|
||||||
|
Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
|
||||||
|
).to(tl.float32)
|
||||||
|
k_trans = tl.load(
|
||||||
|
K_trans_block_ptr + off_block[None, :] * d,
|
||||||
|
mask=off_block[None, :] < n,
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
v = tl.load(
|
||||||
|
V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
# compute
|
||||||
|
qk = tl.dot(q, k_trans) * diag_decay
|
||||||
|
o_intra = tl.dot(qk, v)
|
||||||
|
o_inter = tl.dot(q, kv) * q_decay
|
||||||
|
o = o_intra + o_inter
|
||||||
|
|
||||||
|
# save and update
|
||||||
|
tl.store(
|
||||||
|
O_block_ptr + off_block[:, None] * e,
|
||||||
|
o.to(O_block_ptr.dtype.element_ty),
|
||||||
|
mask=off_block[:, None] < n,
|
||||||
|
)
|
||||||
|
kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
|
||||||
|
off_block += BLOCK
|
||||||
|
|
||||||
|
|
||||||
|
def lightning_attn2(q, k, v, s):
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
s = s.contiguous()
|
||||||
|
|
||||||
|
b, h, n, d = q.shape
|
||||||
|
e = v.shape[-1]
|
||||||
|
|
||||||
|
# Pad d to next power of 2
|
||||||
|
d_padded = next_power_of_2(d)
|
||||||
|
if d_padded != d:
|
||||||
|
q_padded = F.pad(q, (0, d_padded - d))
|
||||||
|
k_padded = F.pad(k, (0, d_padded - d))
|
||||||
|
else:
|
||||||
|
q_padded = q
|
||||||
|
k_padded = k
|
||||||
|
|
||||||
|
# Pad e to next power of 2
|
||||||
|
e_padded = next_power_of_2(e)
|
||||||
|
if e_padded != e:
|
||||||
|
v_padded = F.pad(v, (0, e_padded - e))
|
||||||
|
else:
|
||||||
|
v_padded = v
|
||||||
|
|
||||||
|
o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
BLOCK = 64
|
||||||
|
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
|
||||||
|
# parallel over channel
|
||||||
|
BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
|
||||||
|
grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))
|
||||||
|
|
||||||
|
_fwd_kernel[grid](
|
||||||
|
q_padded,
|
||||||
|
k_padded,
|
||||||
|
v_padded,
|
||||||
|
o_padded,
|
||||||
|
s,
|
||||||
|
b,
|
||||||
|
h,
|
||||||
|
n,
|
||||||
|
d_padded,
|
||||||
|
e_padded,
|
||||||
|
BLOCK=BLOCK,
|
||||||
|
NUM_BLOCK=NUM_BLOCK,
|
||||||
|
BLOCK_MODEL=BLOCK_MODEL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove padding from output
|
||||||
|
if e_padded != e:
|
||||||
|
o = o_padded[..., :e]
|
||||||
|
else:
|
||||||
|
o = o_padded
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
def is_support(dim):
|
||||||
|
return 16 % dim
|
||||||
|
|
||||||
|
|
||||||
|
def next_power_of_2(n):
|
||||||
|
return 2 ** (int(math.ceil(math.log(n, 2))))


def lightning_attn_func(q, k, v, s):
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert is_support(d) and is_support(e)

    # pad v's feature dim to power of 2
    e_pad = next_power_of_2(e)
    need_pad = e_pad != e
    if need_pad:
        v = F.pad(v, (0, e_pad - e))

    if d > 128:
        # split over head
        if 64 % d:
            m = 64
        elif 32 % d:
            m = 32
        elif 16 % d:
            m = 16
        arr = [m * i for i in range(d // m + 1)]
        if arr[-1] != d:
            arr.append(d)
        n = len(arr)
        o = 0
        for i in range(n - 1):
            start = arr[i]
            end = arr[i + 1]
            q1 = q[..., start:end]
            k1 = k[..., start:end]
            o += lightning_attn2(q1, k1, v, s)
    else:
        o = lightning_attn2(q, k, v, s)

    if need_pad:
        o = o[:, :, :, :e]

    return o


debug = eval(os.environ.get("debug", default="False"))

BLOCK = 256
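

# Hypothetical usage sketch (added; not in the original file): shows the tensor
# layout lightning_attn_func expects. It assumes a CUDA device with Triton
# available and reuses _build_slope_tensor, defined later in this module, for
# the per-head decay rates.
def _example_lightning_attn_call():
    b, h, n, d = 2, 64, 1024, 96
    q = torch.randn(b, h, n, d, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(b, h, n, d, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(b, h, n, d, dtype=torch.bfloat16, device="cuda")
    s = _build_slope_tensor(h).to("cuda")  # (h, 1, 1) decay slope per head
    o = lightning_attn_func(q, k, v, s)  # (b, h, n, 96), padding removed
    return o.shape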
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
|
||||||
|
class MiniMaxText01RMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)


# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
    if debug:
        logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":

        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdims=True).values
                y = torch.exp(x - x_max)

            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        logger.info(f"activation: does not support {activation}, use Identity!!!")
        return lambda x: x
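

# Hypothetical example (added; not in the original file): get_activation_fn
# returns a plain callable, which the attention module below applies to the
# fused QKV projection; unknown names fall back to the identity function.
def _example_activation_lookup():
    act = get_activation_fn("silu")  # -> F.silu
    fallback = get_activation_fn("not-a-real-activation")  # identity (logged)
    x = torch.randn(4)
    return act(x), fallback(x)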
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
|
||||||
|
class MiniMaxText01LightningAttention(nn.Module):
|
||||||
|
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
if config is None:
|
||||||
|
config = type("Config", (), kwargs)
|
||||||
|
|
||||||
|
bias = False
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||||
|
|
||||||
|
self.out_proj = nn.Linear(
|
||||||
|
self.head_dim * self.num_heads, self.hidden_size, bias=bias
|
||||||
|
)
|
||||||
|
self.act = get_activation_fn(config.hidden_act)
|
||||||
|
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
|
||||||
|
|
||||||
|
self.qkv_proj = nn.Linear(
|
||||||
|
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
|
||||||
|
)
|
||||||
|
self.output_gate = nn.Linear(
|
||||||
|
self.hidden_size, self.head_dim * self.num_heads, bias=bias
|
||||||
|
)
|
||||||
|
|
||||||
|
# for inference only
|
||||||
|
self.offset = 0
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
|
||||||
|
output_attentions: bool = False,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
slope_rate: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if (not self.training) and (not do_eval):
|
||||||
|
return self.inference(
|
||||||
|
hidden_states,
|
||||||
|
attn_mask,
|
||||||
|
output_attentions,
|
||||||
|
past_key_value,
|
||||||
|
use_cache,
|
||||||
|
slope_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None, # (b, n)
|
||||||
|
output_attentions: bool = False,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
|
||||||
|
):
|
||||||
|
# x: b n d
|
||||||
|
b, n, d = x.shape
|
||||||
|
# linear map
|
||||||
|
qkv = self.act(self.qkv_proj(x))
|
||||||
|
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
|
||||||
|
qkv = qkv.view(*new_shape)
|
||||||
|
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
if past_key_value is None:
|
||||||
|
self.offset = q.shape[-2]
|
||||||
|
else:
|
||||||
|
self.offset += 1
|
||||||
|
|
||||||
|
# for align with metaseq
|
||||||
|
ratio = torch.exp(-slope_rate)
|
||||||
|
|
||||||
|
# only use for the first time
|
||||||
|
if past_key_value is None:
|
||||||
|
slope_rate = slope_rate.to(torch.float32)
|
||||||
|
if attn_mask is not None:
|
||||||
|
v = v.masked_fill(
|
||||||
|
(1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
|
||||||
|
)
|
||||||
|
NUM_BLOCK = (n + BLOCK - 1) // BLOCK
|
||||||
|
b, h, n, d = q.shape
|
||||||
|
e = v.shape[-1]
|
||||||
|
# other
|
||||||
|
array = torch.arange(BLOCK).to(q) + 1
|
||||||
|
q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
|
||||||
|
k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
|
||||||
|
index = array[:, None] - array[None, :]
|
||||||
|
s_index = (
|
||||||
|
slope_rate
|
||||||
|
* index[
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
s_index = torch.where(index >= 0, -s_index, float("-inf"))
|
||||||
|
diag_decay = torch.exp(s_index)
|
||||||
|
|
||||||
|
kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
|
||||||
|
output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
|
||||||
|
for i in range(NUM_BLOCK):
|
||||||
|
si = i * BLOCK
|
||||||
|
ei = min(si + BLOCK, n)
|
||||||
|
m = ei - si
|
||||||
|
qi = q[:, :, si:ei].contiguous()
|
||||||
|
ki = k[:, :, si:ei].contiguous()
|
||||||
|
vi = v[:, :, si:ei].contiguous()
|
||||||
|
qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
|
||||||
|
|
||||||
|
# diag
|
||||||
|
qk = (
|
||||||
|
torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
|
||||||
|
* diag_decay[:, :, :m, :m]
|
||||||
|
)
|
||||||
|
qkv_diag = torch.matmul(qk, vi.to(torch.float32))
|
||||||
|
block_decay = torch.exp(-slope_rate * m)
|
||||||
|
output[:, :, si:ei] = qkv_none_diag + qkv_diag
|
||||||
|
kv = block_decay * kv + torch.matmul(
|
||||||
|
(ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
kv = past_key_value
|
||||||
|
output = []
|
||||||
|
for i in range(n):
|
||||||
|
kv = ratio * kv + torch.einsum(
|
||||||
|
"... n d, ... n e -> ... d e",
|
||||||
|
k[:, :, i : i + 1],
|
||||||
|
v[:, :, i : i + 1],
|
||||||
|
)
|
||||||
|
qkv = torch.einsum(
|
||||||
|
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||||
|
)
|
||||||
|
output.append(qkv)
|
||||||
|
output = torch.cat(output, dim=-2)
|
||||||
|
# reshape
|
||||||
|
output = rearrange(output, "b h n d -> b n (h d)")
|
||||||
|
# normalize
|
||||||
|
output = self.norm(output)
|
||||||
|
# gate
|
||||||
|
output = F.sigmoid(self.output_gate(x)) * output
|
||||||
|
# outproj
|
||||||
|
output = self.out_proj(output)
|
||||||
|
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return output, attn_weights, kv


def _build_slope_tensor(n_attention_heads: int):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(
                n
            )  # In the paper, we only train models that have 2^a heads for some a. This function has
        else:  # some good properties that only occur when the input is a power of 2. To maintain that even
            closest_power_of_2 = 2 ** math.floor(
                math.log2(n)
            )  # when the number of heads is not a power of 2, we use this workaround.
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    # h, 1, 1
    slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
        n_attention_heads, 1, 1
    )

    return slopes
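

# Worked example (added for illustration; not in the original file): for 8 heads
# the slope recursion yields a geometric sequence starting at 2**-1 with ratio
# 2**-1, i.e. 1/2, 1/4, ..., 1/256, returned with shape (8, 1, 1).
def _example_slopes_for_8_heads():
    slopes = _build_slope_tensor(8).flatten().tolist()
    assert slopes == [2 ** -(i + 1) for i in range(8)]
    return slopes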
|
||||||
|
|
||||||
|
|
||||||
|
def test_lightning_attention_implementations(model_params):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
seq_len = 1024
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
hidden_states = torch.randn(
|
||||||
|
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
|
||||||
|
|
||||||
|
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
|
||||||
|
model_attn.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output, _, _ = model_attn.inference(
|
||||||
|
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||||
|
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||||
|
qkv = qkv.view(*new_shape)
|
||||||
|
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
lib_output = lightning_attn_func(q, k, v, slope_rate)
|
||||||
|
lib_output = lib_output.transpose(1, 2).contiguous()
|
||||||
|
lib_output = lib_output.view(batch_size, seq_len, -1)
|
||||||
|
lib_output = model_attn.norm(lib_output)
|
||||||
|
lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
|
||||||
|
lib_output = model_attn.out_proj(lib_output)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
model_output,
|
||||||
|
lib_output,
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-2,
|
||||||
|
msg="Lightning attention implementations produce different results",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✅ Two implementations match")
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark():
|
||||||
|
batch_size_range = [2**i for i in range(0, 7)] # max 64
|
||||||
|
seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096
|
||||||
|
configs = list(itertools.product(batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "seq_len"],
|
||||||
|
x_vals=[list(_) for _ in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["MiniMax-Text-01", "OpenNLPLab"],
|
||||||
|
line_names=[
|
||||||
|
"MiniMax-Text-01 Model Implementation",
|
||||||
|
"OpenNLPLab Library Implementation",
|
||||||
|
],
|
||||||
|
styles=[("blue", "-"), ("green", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="lightning-attention-prefill-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, seq_len, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"hidden_size": 6144,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"head_dim": 96,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
}
|
||||||
|
|
||||||
|
hidden_states = torch.randn(
|
||||||
|
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
|
||||||
|
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
|
||||||
|
model_attn.eval()
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
if provider == "MiniMax-Text-01":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: model_attn.inference(
|
||||||
|
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run_lib():
|
||||||
|
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||||
|
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||||
|
qkv = qkv.view(*new_shape)
|
||||||
|
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
lib_output = lightning_attn_func(q, k, v, slope_rate)
|
||||||
|
lib_output = lib_output.transpose(1, 2).contiguous()
|
||||||
|
lib_output = lib_output.view(batch_size, seq_len, -1)
|
||||||
|
lib_output = model_attn.norm(lib_output)
|
||||||
|
lib_output = (
|
||||||
|
torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
|
||||||
|
)
|
||||||
|
return model_attn.out_proj(lib_output)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
run_lib,
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/lightning_attention_prefill/",
|
||||||
|
help="Path to save lightning attention prefill benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Run correctness test first
|
||||||
|
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
|
||||||
|
params = {
|
||||||
|
"hidden_size": 6144,
|
||||||
|
"num_attention_heads": 64,
|
||||||
|
"head_dim": 96,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
}
|
||||||
|
test_lightning_attention_implementations(params)
|
||||||
|
|
||||||
|
# Run performance benchmark
|
||||||
|
benchmark = get_benchmark()
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|
@ -0,0 +1,94 @@
import argparse

import torch
import triton
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8


@torch.compile(backend="inductor")
def torch_int8_quant(x):
    int8_max = torch.iinfo(torch.int8).max

    abs_max = x.abs().max(dim=-1, keepdim=True).values
    scales = abs_max.to(torch.float32) / float(int8_max)

    q_x = (x / scales).round().to(torch.int8)

    return q_x, scales
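

# Illustrative sketch (added; not part of the original benchmark): the per-token
# scales returned above allow an approximate reconstruction of the input, with
# each element's error bounded by roughly half a quantization step. Assumes a
# CUDA device.
def _example_int8_roundtrip():
    x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
    q_x, scales = torch_int8_quant(x)
    x_restored = q_x.to(torch.float32) * scales  # dequantize row by row
    return (x_restored - x.to(torch.float32)).abs().max()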
|
||||||
|
|
||||||
|
|
||||||
|
def _test_accuracy_once(M, K, input_dtype, device):
|
||||||
|
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
|
||||||
|
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
|
||||||
|
out1, scales1 = per_token_quant_int8(x)
|
||||||
|
out2, scales2 = torch_int8_quant(x)
|
||||||
|
torch.testing.assert_close(out, out2, atol=1, rtol=0)
|
||||||
|
torch.testing.assert_close(out, out1, atol=1, rtol=0)
|
||||||
|
torch.testing.assert_close(scales, scales2)
|
||||||
|
torch.testing.assert_close(scales1, scales2)
|
||||||
|
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
|
||||||
|
|
||||||
|
|
||||||
|
def test_accuracy():
|
||||||
|
Ms = [1, 13, 128, 1024, 2048, 4096]
|
||||||
|
Ks = [512, 1024, 2048, 8192]
|
||||||
|
input_dtypes = [torch.float16, torch.bfloat16]
|
||||||
|
for M in Ms:
|
||||||
|
for K in Ks:
|
||||||
|
for input_dtype in input_dtypes:
|
||||||
|
_test_accuracy_once(M, K, input_dtype, "cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["vllm op", "triton", "torch.compile"],
|
||||||
|
line_names=["vllm op", "triton", "torch.compile"],
|
||||||
|
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name="int8 per token quant",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider):
|
||||||
|
M, K = batch_size, 16384
|
||||||
|
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
if provider == "vllm op":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: vllm_scaled_int8_quant(x, symmetric=True),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
if provider == "triton":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: per_token_quant_int8(x),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
if provider == "torch.compile":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: torch_int8_quant(x),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./bench_int8_quant_res",
|
||||||
|
help="Path to save int8 quant benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
test_accuracy()
|
||||||
|
|
||||||
|
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||||
|
|
@ -0,0 +1,474 @@
|
||||||
|
# Copyright 2025 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
_w8a8_block_fp8_matmul,
|
||||||
|
_w8a8_block_fp8_matmul_unrolledx4,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
||||||
|
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
||||||
|
|
||||||
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
|
DTYPE_MAP = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"half": torch.half,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def w8a8_block_matmul(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: List[int],
|
||||||
|
config: Dict[str, Any],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""This function performs matrix multiplication with block-wise quantization.
|
||||||
|
|
||||||
|
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||||
|
The output is returned in the specified `output_dtype`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
A: The input tensor, e.g., activation.
|
||||||
|
B: The input tensor, e.g., weight.
|
||||||
|
As: The per-token-group quantization scale for `A`.
|
||||||
|
Bs: The per-block quantization scale for `B`.
|
||||||
|
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
|
||||||
|
        output_dtype: The dtype of the returned tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The result of matmul.
|
||||||
|
"""
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||||
|
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
|
||||||
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
|
N, K = B.shape
|
||||||
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = A.shape[:-1] + (N,)
|
||||||
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||||
|
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||||
|
# compute units available on the device.
|
||||||
|
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||||
|
N, config["BLOCK_SIZE_N"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
||||||
|
kernel = (
|
||||||
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
|
            if (is_hip_ and num_workgroups <= get_device_core_count())
|
||||||
|
else _w8a8_block_fp8_matmul
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kernel = _w8a8_block_int8_matmul
|
||||||
|
|
||||||
|
kernel[grid](
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
A.stride(-2),
|
||||||
|
A.stride(-1),
|
||||||
|
B.stride(1),
|
||||||
|
B.stride(0),
|
||||||
|
C.stride(-2),
|
||||||
|
C.stride(-1),
|
||||||
|
As.stride(-2),
|
||||||
|
As.stride(-1),
|
||||||
|
Bs.stride(1),
|
||||||
|
Bs.stride(0),
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return C
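

# Hypothetical usage sketch (added; not in the original script): drives the
# kernel the same way tune() does below, here with int8 inputs, per-token-group
# scales for A, per-block scales for B, and one hand-written tiling config
# drawn from the same search space as get_configs_compute_bound().
def _example_w8a8_int8_call():
    M, N, K = 64, 512, 1024
    block_size = [128, 128]  # [block_n, block_k]
    A = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
    B = torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda")
    As = torch.rand(M, K // 128, dtype=torch.float32, device="cuda") * 1e-2
    Bs = torch.rand(N // 128, K // 128, dtype=torch.float32, device="cuda") * 1e-2
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3,
    }
    return w8a8_block_matmul(A, B, As, Bs, block_size, config, torch.float16)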
|
||||||
|
|
||||||
|
|
||||||
|
def get_rocm_configs_compute_bound():
|
||||||
|
configs = []
|
||||||
|
waves_per_eu_range = 0
|
||||||
|
for num_stages in [2]:
|
||||||
|
for block_m in [32, 64, 128, 256]:
|
||||||
|
for block_k in [32, 64, 128, 256]:
|
||||||
|
for block_n in [16, 32, 64, 128, 256]:
|
||||||
|
for num_warps in [4, 8]:
|
||||||
|
for group_size in [1, 4, 8, 16, 32]:
|
||||||
|
configs.append(
|
||||||
|
{
|
||||||
|
"BLOCK_SIZE_M": block_m,
|
||||||
|
"BLOCK_SIZE_N": block_n,
|
||||||
|
"BLOCK_SIZE_K": block_k,
|
||||||
|
"GROUP_SIZE_M": group_size,
|
||||||
|
"num_warps": num_warps,
|
||||||
|
"num_stages": num_stages,
|
||||||
|
"waves_per_eu": waves_per_eu_range,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound():
|
||||||
|
configs = []
|
||||||
|
if is_hip_:
|
||||||
|
configs = get_rocm_configs_compute_bound()
|
||||||
|
else:
|
||||||
|
for num_stages in [2, 3, 4, 5]:
|
||||||
|
for block_m in [16, 32, 64, 128, 256]:
|
||||||
|
for block_k in [64, 128]:
|
||||||
|
for block_n in [32, 64, 128, 256]:
|
||||||
|
for num_warps in [4, 8]:
|
||||||
|
for group_size in [1, 16, 32, 64]:
|
||||||
|
configs.append(
|
||||||
|
{
|
||||||
|
"BLOCK_SIZE_M": block_m,
|
||||||
|
"BLOCK_SIZE_N": block_n,
|
||||||
|
"BLOCK_SIZE_K": block_k,
|
||||||
|
"GROUP_SIZE_M": group_size,
|
||||||
|
"num_warps": num_warps,
|
||||||
|
"num_stages": num_stages,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_shapes(tp_size):
|
||||||
|
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
|
||||||
|
# cannot TP
|
||||||
|
total = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
]
|
||||||
|
# N can TP
|
||||||
|
n_tp = [
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(24576, 1536),
|
||||||
|
(4096, 7168),
|
||||||
|
]
|
||||||
|
# K can TP
|
||||||
|
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||||
|
|
||||||
|
weight_shapes = []
|
||||||
|
for t in total:
|
||||||
|
weight_shapes.append(t)
|
||||||
|
for n_t in n_tp:
|
||||||
|
new_t = (n_t[0] // tp_size, n_t[1])
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
for k_t in k_tp:
|
||||||
|
new_t = (k_t[0], k_t[1] // tp_size)
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
return weight_shapes
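

# Worked example (added for illustration): with --tp-size 8, the N-split entry
# (18432 * 2, 7168) becomes (4608, 7168) and the K-split entry (7168, 16384)
# becomes (7168, 2048), while the "cannot TP" shapes are kept unchanged.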
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_config(
|
||||||
|
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||||
|
):
|
||||||
|
def run():
|
||||||
|
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
    # JIT compilation & warmup
|
||||||
|
for _ in range(5):
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: List[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
run()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
|
||||||
|
if input_type == "fp8":
|
||||||
|
fp8_info = torch.finfo(
|
||||||
|
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
|
A_fp32 = (
|
||||||
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
|
)
|
||||||
|
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||||
|
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
B_fp32 = (
|
||||||
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
|
)
|
||||||
|
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||||
|
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
int8_info = torch.iinfo(torch.int8)
|
||||||
|
int8_max, int8_min = int8_info.max, int8_info.min
|
||||||
|
|
||||||
|
A_fp32 = (
|
||||||
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||||
|
)
|
||||||
|
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||||
|
|
||||||
|
B_fp32 = (
|
||||||
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||||
|
)
|
||||||
|
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||||
|
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
|
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
|
||||||
|
Bs = (
|
||||||
|
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
best_config = None
|
||||||
|
best_time = float("inf")
|
||||||
|
for config in tqdm(search_space):
|
||||||
|
try:
|
||||||
|
kernel_time = benchmark_config(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
block_size,
|
||||||
|
config,
|
||||||
|
out_dtype,
|
||||||
|
num_iters=10,
|
||||||
|
)
|
||||||
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
|
# Some configurations may be invalid and fail to compile.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if kernel_time < best_time:
|
||||||
|
best_time = kernel_time
|
||||||
|
best_config = config
|
||||||
|
now = datetime.now()
|
||||||
|
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||||
|
assert best_config is not None
|
||||||
|
return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
configs,
|
||||||
|
save_path,
|
||||||
|
input_type="fp8",
|
||||||
|
) -> None:
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
device_name = get_device_name().replace(" ", "_")
|
||||||
|
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
|
||||||
|
|
||||||
|
config_file_path = os.path.join(save_path, json_file_name)
|
||||||
|
print(f"Writing best config to {config_file_path}...")
|
||||||
|
|
||||||
|
with open(config_file_path, "w") as f:
|
||||||
|
json.dump(configs, f, indent=4)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_gpu_count():
|
||||||
|
"""Get the number of available GPUs."""
|
||||||
|
return torch.cuda.device_count()
|
||||||
|
|
||||||
|
|
||||||
|
def tune_on_gpu(args_dict):
|
||||||
|
"""Run tuning on a specific GPU."""
|
||||||
|
gpu_id = args_dict["gpu_id"]
|
||||||
|
batch_sizes = args_dict["batch_sizes"]
|
||||||
|
weight_shapes = args_dict["weight_shapes"]
|
||||||
|
args = args_dict["args"]
|
||||||
|
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||||
|
|
||||||
|
block_n = args.block_n
|
||||||
|
block_k = args.block_k
|
||||||
|
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||||
|
save_path = args.save_path
|
||||||
|
input_type = args.input_type
|
||||||
|
|
||||||
|
search_space = get_configs_compute_bound()
|
||||||
|
search_space = [
|
||||||
|
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
results = {}
|
||||||
|
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||||
|
N, K = shape[0], shape[1]
|
||||||
|
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||||
|
benchmark_results = [
|
||||||
|
tune(
|
||||||
|
batch_size,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
[block_n, block_k],
|
||||||
|
out_dtype,
|
||||||
|
search_space,
|
||||||
|
input_type,
|
||||||
|
)
|
||||||
|
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||||
|
]
|
||||||
|
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||||
|
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||||
|
"""Distribute batch sizes across available GPUs."""
|
||||||
|
batches_per_gpu = []
|
||||||
|
for i in range(num_gpus):
|
||||||
|
start_idx = i * len(batch_sizes) // num_gpus
|
||||||
|
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||||
|
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||||
|
return batches_per_gpu
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
num_gpus = get_available_gpu_count()
|
||||||
|
if num_gpus == 0:
|
||||||
|
raise RuntimeError("No GPU available for tuning")
|
||||||
|
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||||
|
|
||||||
|
torch.cuda.init()
|
||||||
|
|
||||||
|
if args.batch_size is None:
|
||||||
|
batch_sizes = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
32,
|
||||||
|
48,
|
||||||
|
64,
|
||||||
|
96,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
3072,
|
||||||
|
4096,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch_sizes = [args.batch_size]
|
||||||
|
num_gpus = 1 # If only one batch size, use only one GPU
|
||||||
|
|
||||||
|
weight_shapes = get_weight_shapes(args.tp_size)
|
||||||
|
|
||||||
|
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||||
|
|
||||||
|
process_args = []
|
||||||
|
for gpu_id in range(num_gpus):
|
||||||
|
process_args.append(
|
||||||
|
{
|
||||||
|
"gpu_id": gpu_id,
|
||||||
|
"batch_sizes": batches_per_gpu[gpu_id],
|
||||||
|
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
||||||
|
"args": args,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = mp.get_context("spawn")
|
||||||
|
with ctx.Pool(num_gpus) as pool:
|
||||||
|
pool.map(tune_on_gpu, process_args)
|
||||||
|
|
||||||
|
print("Multi-GPU tuning completed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["float32", "float16", "bfloat16", "half"],
|
||||||
|
default="float16",
|
||||||
|
)
|
||||||
|
parser.add_argument("--block-n", type=int, default=128)
|
||||||
|
parser.add_argument("--block-k", type=int, default=128)
|
||||||
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,230 @@
|
||||||
|
import itertools
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||||
|
from torch import nn
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceRMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
if residual is not None:
|
||||||
|
x = x + residual.to(torch.float32)
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
x = x.to(orig_dtype) * self.weight
|
||||||
|
if residual is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x, residual
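

# Added note (not in the original benchmark): every backend below computes
# RMSNorm, i.e. y = weight * x / sqrt(mean(x**2, dim=-1) + eps). When a
# residual is passed, it is added to x first and the pair
# (normalized_x, updated_residual) is returned; otherwise a single tensor is returned.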
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_naive(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||||
|
naive_norm.weight = nn.Parameter(weight)
|
||||||
|
naive_norm = naive_norm.to(x.device)
|
||||||
|
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
output = naive_norm(x, residual)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_flashinfer(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
fused_add_rmsnorm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
output = rmsnorm(x, weight, eps)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_vllm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
vllm_ops.rms_norm(out, x, weight, eps)
|
||||||
|
output = out
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
output_naive = rmsnorm_naive(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_flashinfer = rmsnorm_flashinfer(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_vllm = rmsnorm_vllm(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_residual:
|
||||||
|
output_naive = output_naive[0]
|
||||||
|
output_flashinfer = output_flashinfer[0]
|
||||||
|
output_vllm = output_vllm[0]
|
||||||
|
|
||||||
|
print(f"Naive output={output_naive}")
|
||||||
|
print(f"FlashInfer output={output_flashinfer}")
|
||||||
|
print(f"VLLM output={output_vllm}")
|
||||||
|
|
||||||
|
if torch.allclose(
|
||||||
|
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||||
|
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||||
|
print("✅ All implementations match")
|
||||||
|
else:
|
||||||
|
print("❌ Implementations differ")
|
||||||
|
|
||||||
|
|
||||||
|
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||||
|
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||||
|
head_num_range = [32, 48]
|
||||||
|
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark(use_residual):
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["head_num", "batch_size", "seq_len"],
|
||||||
|
x_vals=[list(_) for _ in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["huggingface", "flashinfer", "vllm"],
|
||||||
|
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(head_num, batch_size, seq_len, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
hidden_size = head_num * 128 # assuming head_dim = 128
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "huggingface":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_naive(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "flashinfer":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_flashinfer(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_vllm(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_residual", action="store_true", help="Whether to use residual connection"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/rmsnorm/",
|
||||||
|
help="Path to save rmsnorm benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Run correctness test
|
||||||
|
calculate_diff(
|
||||||
|
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the benchmark function with proper use_residual setting
|
||||||
|
benchmark = get_benchmark(args.use_residual)
|
||||||
|
# Run performance benchmark
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|
@ -0,0 +1,342 @@
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def write_req_to_token_pool_triton(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
req_to_token_ptr_stride: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
req_pool_index = tl.load(req_pool_indices + pid)
|
||||||
|
pre_len = tl.load(pre_lens + pid)
|
||||||
|
seq_len = tl.load(seq_lens + pid)
|
||||||
|
|
||||||
|
# TODO: optimize this?
|
||||||
|
cumsum_start = 0
|
||||||
|
for i in range(pid):
|
||||||
|
cumsum_start += tl.load(extend_lens + i)
|
||||||
|
|
||||||
|
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
||||||
|
for i in range(num_loop):
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||||
|
mask = offset < (seq_len - pre_len)
|
||||||
|
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
||||||
|
tl.store(
|
||||||
|
req_to_token_ptr
|
||||||
|
+ req_pool_index * req_to_token_ptr_stride
|
||||||
|
+ offset
|
||||||
|
+ pre_len,
|
||||||
|
value,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def write_req_to_token_pool_triton_optimize(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
req_to_token_ptr_stride: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid_batch = tl.program_id(0)
|
||||||
|
pid_token = tl.program_id(1)
|
||||||
|
|
||||||
|
req_pool_index = tl.load(req_pool_indices + pid_batch)
|
||||||
|
pre_len = tl.load(pre_lens + pid_batch)
|
||||||
|
seq_len = tl.load(seq_lens + pid_batch)
|
||||||
|
extend_len = seq_len - pre_len
|
||||||
|
|
||||||
|
cumsum_start = 0
|
||||||
|
for i in range(pid_batch):
|
||||||
|
cumsum_start += tl.load(extend_lens + i)
|
||||||
|
|
||||||
|
token_start = pid_token * BLOCK_SIZE
|
||||||
|
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
actual_offset = token_start + offset
|
||||||
|
mask = actual_offset < extend_len
|
||||||
|
|
||||||
|
src_ptr = out_cache_loc + cumsum_start + actual_offset
|
||||||
|
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||||
|
value = tl.load(src_ptr, mask=mask)
|
||||||
|
dst_ptr = (
|
||||||
|
req_to_token_ptr
|
||||||
|
+ req_pool_index * req_to_token_ptr_stride
|
||||||
|
+ actual_offset
|
||||||
|
+ pre_len
|
||||||
|
)
|
||||||
|
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||||
|
|
||||||
|
tl.store(dst_ptr, value, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def write_req_to_token_pool_reference(
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
pre_lens: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
extend_lens: torch.Tensor,
|
||||||
|
out_cache_loc: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Reference implementation using PyTorch"""
|
||||||
|
for i in range(len(req_pool_indices)):
|
||||||
|
req_pool_idx = req_pool_indices[i].item()
|
||||||
|
pre_len = pre_lens[i].item()
|
||||||
|
seq_len = seq_lens[i].item()
|
||||||
|
extend_len = extend_lens[i].item()
|
||||||
|
|
||||||
|
cumsum_start = sum(extend_lens[:i].tolist())
|
||||||
|
|
||||||
|
# Copy values from out_cache_loc to req_to_token
|
||||||
|
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
|
||||||
|
cumsum_start : cumsum_start + extend_len
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_req_to_token_pool():
|
||||||
|
max_batch = 4097
|
||||||
|
max_context_len = 6148
|
||||||
|
batch_size = 1
|
||||||
|
extend_len = 14
|
||||||
|
|
||||||
|
# Initialize input tensors
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
|
||||||
|
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
|
||||||
|
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# Create copies for reference implementation
|
||||||
|
req_to_token_ref = req_to_token.clone()
|
||||||
|
req_to_token_opt = req_to_token.clone()
|
||||||
|
|
||||||
|
# Run original triton kernel
|
||||||
|
write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run optimized triton kernel
|
||||||
|
def grid(batch_size, extend_len):
|
||||||
|
num_token_blocks = triton.cdiv(extend_len, 512)
|
||||||
|
return (batch_size, num_token_blocks)
|
||||||
|
|
||||||
|
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
|
||||||
|
req_to_token_opt,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run reference implementation
|
||||||
|
write_req_to_token_pool_reference(
|
||||||
|
req_to_token_ref,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||||
|
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||||
|
|
||||||
|
# Test case 2: batch size > 1
|
||||||
|
batch_size = 3
|
||||||
|
extend_lens_list = [14, 20, 30]
|
||||||
|
total_extend_len = sum(extend_lens_list)
|
||||||
|
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
|
||||||
|
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
|
||||||
|
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
req_to_token_ref = req_to_token.clone()
|
||||||
|
req_to_token_opt = req_to_token.clone()
|
||||||
|
|
||||||
|
# Run original triton kernel
|
||||||
|
write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run optimized triton kernel
|
||||||
|
max_extend_len = max(extend_lens_list)
|
||||||
|
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
|
||||||
|
req_to_token_opt,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run reference implementation
|
||||||
|
write_req_to_token_pool_reference(
|
||||||
|
req_to_token_ref,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||||
|
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark():
|
||||||
|
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
|
||||||
|
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
|
configs = list(itertools.product(batch_sizes, extend_lens))
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "extend_len"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["reference", "triton", "triton_optimize"],
|
||||||
|
line_names=["PyTorch", "Triton", "Triton Optimized"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="write-req-to-token-pool-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, extend_len, provider):
|
||||||
|
max_batch = 256
|
||||||
|
max_context_len = 16384
|
||||||
|
|
||||||
|
extend_lens_list = [extend_len] * batch_size
|
||||||
|
total_extend_len = sum(extend_lens_list)
|
||||||
|
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
|
||||||
|
seq_lens = pre_lens + extend_len
|
||||||
|
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "reference":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: write_req_to_token_pool_reference(
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "triton":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run_optimized():
|
||||||
|
block_size = 128 if extend_len <= 1024 else 512
|
||||||
|
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
|
||||||
|
write_req_to_token_pool_triton_optimize[grid_config](
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
run_optimized, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
|
||||||
|
"""Run benchmark and save results"""
|
||||||
|
|
||||||
|
# Ensure save path exists
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Run correctness test
|
||||||
|
test_write_req_to_token_pool()
|
||||||
|
print("Correctness test passed!")
|
||||||
|
|
||||||
|
# Run performance test
|
||||||
|
benchmark = get_benchmark()
|
||||||
|
benchmark.run(print_data=True, save_path=save_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/write_req_to_token_pool/",
|
||||||
|
help="Path to save benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
run_benchmark(args.save_path)
|
||||||
|
|
@ -0,0 +1,37 @@
## Download data

```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```

## Run benchmark

### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```

```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```


### Results

```
# original
Accuracy: 0.940, latency: 332.83 s

# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s

# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```
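A quick way to inspect the generated data before benchmarking (assuming the default `--redirect-ratio 0.0`, which makes `gen_data.py` write `lines_1000_0.0.json`):

```
# Minimal sketch: peek at the file produced by gen_data.py.
import json

obj = json.load(open("lines_1000_0.0.json"))
print(len(obj["lines"]), "lines")
print(obj["lines"][0])      # e.g. "Line foo-bar: The REGISTER_CONTENT is 012345."
print(obj["suffix"][:80])   # query template with the few-shot examples filled in
print({k: len(v) for k, v in obj["group_by_num_hoops"].items()})  # lines per hop count
```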
@ -0,0 +1,149 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
|
||||||
|
s += prefix + "\n"
|
||||||
|
|
||||||
|
contexts = [body_0, body_1, body_2, body_3]
|
||||||
|
position_ids_offset = [i * 1000 for i in range(len(contexts))]
|
||||||
|
forks = s.fork(len(contexts), position_ids_offset)
|
||||||
|
forks += lambda i: contexts[i] + "\n"
|
||||||
|
forks.join(mode="concate_and_append")
|
||||||
|
|
||||||
|
s += "\n" + suffix
|
||||||
|
s += sgl.gen("answer", max_tokens=16)
|
||||||
|
|
||||||
|
|
||||||
|
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||||
|
arguments = []
|
||||||
|
labels = []
|
||||||
|
sum_src_indices = []
|
||||||
|
sum_dst_indices = []
|
||||||
|
|
||||||
|
for i in range(len(src_indices)):
|
||||||
|
for j in range(len(dst_percents)):
|
||||||
|
src_index = src_indices[i]
|
||||||
|
dst_percent = dst_percents[j]
|
||||||
|
|
||||||
|
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
||||||
|
query_indices = [
|
||||||
|
q
|
||||||
|
for q in query_indices
|
||||||
|
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
|
||||||
|
]
|
||||||
|
dst_index = query_indices[
|
||||||
|
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
|
||||||
|
]
|
||||||
|
label = line_obj["values"][dst_index]
|
||||||
|
|
||||||
|
body = line_obj["lines"][: src_index + 1]
|
||||||
|
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
||||||
|
body_part_len = len(body) // 4
|
||||||
|
|
||||||
|
arguments.append(
|
||||||
|
{
|
||||||
|
"prefix": line_obj["prefix"],
|
||||||
|
"body_0": "\n".join(body[:body_part_len]),
|
||||||
|
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
|
||||||
|
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
|
||||||
|
"body_3": "\n".join(body[3 * body_part_len :]),
|
||||||
|
"suffix": suffix,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
labels.append(label)
|
||||||
|
sum_src_indices.append(src_index)
|
||||||
|
sum_dst_indices.append(dst_index)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
states = line_retrieval.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
corrects = []
|
||||||
|
for i in range(len(arguments)):
|
||||||
|
output = states[i]["answer"]
|
||||||
|
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
|
||||||
|
label = labels[i]
|
||||||
|
|
||||||
|
# Try all numbers
|
||||||
|
findall = re.findall(r"\d+", output)
|
||||||
|
if not findall:
|
||||||
|
response_number = output
|
||||||
|
else:
|
||||||
|
for response_number in findall:
|
||||||
|
if response_number == label:
|
||||||
|
break
|
||||||
|
|
||||||
|
correct = response_number == label
|
||||||
|
corrects.append(correct)
|
||||||
|
|
||||||
|
# Log results
|
||||||
|
summary = (
|
||||||
|
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
|
||||||
|
f"Prompt len: {prompt_len}, "
|
||||||
|
f"Correct: {correct}, "
|
||||||
|
f"Label: {label}, Predicted: {response_number}, "
|
||||||
|
)
|
||||||
|
print(summary)
|
||||||
|
|
||||||
|
accuracy = np.mean(corrects)
|
||||||
|
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "line_retrieval",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": len(arguments),
|
||||||
|
"other": {
|
||||||
|
"num_questions": len(arguments),
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
line_obj = json.load(open(args.data_path, "r"))
|
||||||
|
|
||||||
|
num_hoops = args.num_hoops
|
||||||
|
for src_index in args.src_index:
|
||||||
|
src_indices = [src_index]
|
||||||
|
num_queries = args.num_queries_per_src
|
||||||
|
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
|
||||||
|
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
|
||||||
|
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
|
||||||
|
parser.add_argument("--num-queries-per-src", type=int, default=10)
|
||||||
|
parser.add_argument("--num-hoops", type=int, default=1)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
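For readers tracing the selection logic in `eval_model` above, the mapping from a destination percentile to a concrete line index boils down to one expression; a tiny worked example with made-up candidate indices:

```
# Hypothetical candidate indices that passed the filtering in eval_model.
query_indices = [3, 17, 42, 88, 99]

for dst_percent in (0.0, 0.5, 0.9):
    dst_index = query_indices[
        min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
    ]
    print(dst_percent, "->", dst_index)  # 0.0 -> 3, 0.5 -> 42, 0.9 -> 99
```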
@ -0,0 +1,139 @@
|
||||||
|
"""
|
||||||
|
Generate line data for line retrieval task.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 gen_data.py --number 1000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def generate_lines(random_words, num_lines, redirect_ratio):
|
||||||
|
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
|
||||||
|
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
|
||||||
|
|
||||||
|
# Raw lines
|
||||||
|
visited_indices = set([None])
|
||||||
|
visited_values = set([None])
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
redirects = []
|
||||||
|
indices = []
|
||||||
|
values = []
|
||||||
|
for i in tqdm(range(num_lines)):
|
||||||
|
line_index = None
|
||||||
|
while line_index in visited_indices:
|
||||||
|
line_index = "-".join(np.random.choice(random_words, size=(2,)))
|
||||||
|
visited_indices.add(line_index)
|
||||||
|
|
||||||
|
line_value = np.random.randint(low=0, high=999999)
|
||||||
|
line_value = f"{line_value:06}"
|
||||||
|
|
||||||
|
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
|
||||||
|
lines.append(line)
|
||||||
|
redirects.append(None)
|
||||||
|
indices.append(line_index)
|
||||||
|
values.append(line_value)
|
||||||
|
|
||||||
|
# Add redirect
|
||||||
|
if redirect_ratio > 0:
|
||||||
|
num_redirect_lines = int(len(lines) * redirect_ratio)
|
||||||
|
redirect_indices = np.random.choice(
|
||||||
|
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
|
||||||
|
)
|
||||||
|
for i in redirect_indices:
|
||||||
|
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||||
|
lines[i] = (
|
||||||
|
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||||
|
)
|
||||||
|
redirects[i] = target_idx
|
||||||
|
|
||||||
|
# Build links and find sources
|
||||||
|
links = [[] for _ in range(num_lines)]
|
||||||
|
contains_ring = set()
|
||||||
|
for i in range(num_lines):
|
||||||
|
if redirects[i] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tmp_link = []
|
||||||
|
cur = i
|
||||||
|
visited = set()
|
||||||
|
while redirects[cur] is not None:
|
||||||
|
visited.add(cur)
|
||||||
|
tmp_link.append(redirects[cur])
|
||||||
|
cur = redirects[cur]
|
||||||
|
|
||||||
|
if cur in visited:
|
||||||
|
contains_ring.add(i)
|
||||||
|
tmp_link = None
|
||||||
|
break
|
||||||
|
values[i] = values[cur]
|
||||||
|
links[i] = tmp_link
|
||||||
|
|
||||||
|
# Group by num_links
|
||||||
|
group_by_num_hoops = defaultdict(list)
|
||||||
|
for i in range(num_lines):
|
||||||
|
if i in contains_ring:
|
||||||
|
continue
|
||||||
|
group_by_num_hoops[len(links[i]) + 1].append(i)
|
||||||
|
|
||||||
|
keys = sorted(list(group_by_num_hoops.keys()))
|
||||||
|
for num_links in keys:
|
||||||
|
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
|
||||||
|
|
||||||
|
# Append few-shot examples
|
||||||
|
hoop1_candidates = list(group_by_num_hoops[1])
|
||||||
|
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
|
||||||
|
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
|
||||||
|
hoop2_candidates = list(group_by_num_hoops[2])
|
||||||
|
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
|
||||||
|
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
|
||||||
|
|
||||||
|
i = hoop1_candidates[5]
|
||||||
|
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
|
||||||
|
if len(hoop2_candidates):
|
||||||
|
i = hoop2_candidates[0]
|
||||||
|
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||||
|
i = hoop2_candidates[1]
|
||||||
|
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||||
|
else:
|
||||||
|
i = hoop1_candidates[1]
|
||||||
|
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||||
|
i = hoop1_candidates[10]
|
||||||
|
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||||
|
|
||||||
|
obj = {
|
||||||
|
"prefix": prefix,
|
||||||
|
"suffix": suffix,
|
||||||
|
"lines": lines,
|
||||||
|
"indices": indices,
|
||||||
|
"values": values,
|
||||||
|
"links": links,
|
||||||
|
"group_by_num_hoops": group_by_num_hoops,
|
||||||
|
"contains_ring": sorted(list(contains_ring)),
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--number", type=int)
|
||||||
|
parser.add_argument("--redirect-ratio", type=float, default=0.0)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
num_lines = args.number
|
||||||
|
|
||||||
|
random_words_filename = "random_words.json"
|
||||||
|
random_words = json.load(open(random_words_filename, "r"))
|
||||||
|
|
||||||
|
np.random.seed(42)
|
||||||
|
obj = generate_lines(random_words, num_lines, args.redirect_ratio)
|
||||||
|
|
||||||
|
fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
|
||||||
|
with open(fout, "w") as fout:
|
||||||
|
json.dump(obj, fout, indent=2)
|
||||||
|
|
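A stripped-down sketch of the redirect-resolution loop above, run on a toy redirects table (made-up data), shows how links and ring detection behave:

```
# Toy version of the link-building loop: follow redirects until a plain line
# or a cycle is reached. None means "no redirect".
redirects = [None, 0, 1, 4, 3]  # lines 3 and 4 redirect to each other (a ring)

links, contains_ring = [[] for _ in redirects], set()
for i, r in enumerate(redirects):
    if r is None:
        continue
    cur, chain, visited = i, [], set()
    while redirects[cur] is not None:
        visited.add(cur)
        chain.append(redirects[cur])
        cur = redirects[cur]
        if cur in visited:
            contains_ring.add(i)
            chain = None
            break
    links[i] = chain

print(links)          # [[], [0], [1, 0], None, None]
print(contains_ring)  # {3, 4}
```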
@ -0,0 +1,61 @@
## Download benchmark images

```
python3 download_images.py
```

The image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild

### Other dependencies
```
pip3 install "sglang[all]"
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
```

## Run benchmark

### Benchmark sglang
Launch a server
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
```

Run benchmark
```
# Run with local models
python3 bench_sglang.py --num-questions 60

# Run with OpenAI models
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
```

### Benchmark the original LLaVA code
```
git clone git@github.com:haotian-liu/LLaVA.git
cd LLaVA
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
pip3 install -e .

cd ~/sglang/benchmark/llava_bench
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
```


### Benchmark llama.cpp

```
# Install
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
pip install sse_starlette starlette_context pydantic_settings

# Download weights
mkdir -p ~/model_weights/llava-v1.5-7b/
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
```

```
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000

OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
```
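`bench_sglang.py` below reads only the `image`, `text`, and `question_id` fields of each record in `questions.jsonl`; a minimal hand-written record (values are made up) can be appended like this:

```
# Minimal sketch of one questions.jsonl record as consumed by bench_sglang.py
# (it reads "image", "text", and "question_id"; everything else is ignored).
import json

record = {
    "question_id": 0,                          # copied into the answer file
    "image": "001.jpg",                        # resolved relative to --image-folder
    "text": "Describe this photo in detail.",  # the question prompt
}
with open("questions.jsonl", "a") as fout:
    fout.write(json.dumps(record) + "\n")
```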
@ -0,0 +1,9 @@
#!/bin/bash

python -m llava.eval.model_vqa \
    --model-path liuhaotian/llava-v1.5-7b \
    --question-file ./questions.jsonl \
    --image-folder ./images \
    --answers-file ./answers_hf.jsonl \
    --temperature 0 \
    --conv-mode vicuna_v1
@ -0,0 +1,9 @@
#!/bin/bash

python -m llava.eval.model_vqa_loader \
    --model-path liuhaotian/llava-v1.5-7b \
    --question-file ./mme_pack/llava_mme_bench_replace.jsonl \
    --image-folder ./mme_pack/MME_Benchmark_release_version \
    --answers-file ./answers_hf_mme.jsonl \
    --temperature 0 \
    --conv-mode vicuna_v1
@ -0,0 +1,96 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def image_qa(s, image_file, question):
|
||||||
|
s += sgl.user(sgl.image(image_file) + question)
|
||||||
|
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = list(read_jsonl(args.question_file))[: args.num_questions]
|
||||||
|
arguments = [
|
||||||
|
{
|
||||||
|
"image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
|
||||||
|
"question": l["text"],
|
||||||
|
}
|
||||||
|
for l in lines
|
||||||
|
]
|
||||||
|
# arguments = [
|
||||||
|
# {"image_file":
|
||||||
|
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
|
||||||
|
# "question": l["text"]} for l in lines
|
||||||
|
# ]
|
||||||
|
|
||||||
|
states = [None] * len(lines)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm.tqdm(range(len(lines))):
|
||||||
|
image_file = arguments[i]["image_file"]
|
||||||
|
question = arguments[i]["question"]
|
||||||
|
ret = image_qa.run(image_file=image_file, question=question, temperature=0)
|
||||||
|
states[i] = ret
|
||||||
|
else:
|
||||||
|
states = image_qa.run_batch(
|
||||||
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
print(f"Write output to {args.answer_file}")
|
||||||
|
with open(args.answer_file, "w") as fout:
|
||||||
|
for i in range(len(lines)):
|
||||||
|
value = {
|
||||||
|
"question_id": lines[i]["question_id"],
|
||||||
|
"prompt": lines[i]["text"],
|
||||||
|
"text": states[i]["answer"].strip(),
|
||||||
|
"model_id": backend.model_info["model_path"],
|
||||||
|
"answer_id": i,
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "llava_bench",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": len(lines),
|
||||||
|
"parallel": args.parallel,
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--question-file", type=str, default="questions.jsonl")
|
||||||
|
parser.add_argument("--answer-file", type=str, default="answers.jsonl")
|
||||||
|
parser.add_argument("--image-folder", type=str, default="./images")
|
||||||
|
parser.add_argument("--temperature", type=float, default=0.0)
|
||||||
|
parser.add_argument("--num-questions", type=int, default=None)
|
||||||
|
parser.add_argument("--max-tokens", type=int, default=768)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,2 @@
MME_FOLDER=./mme_pack
python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4
@ -0,0 +1,20 @@
import os

# Create the 'images' directory if it doesn't exist
if not os.path.exists("images"):
    os.makedirs("images")

# Base URL
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"

# Loop through image numbers
for i in range(1, 25):
    # Format the image number with leading zeros
    image_number = str(i).zfill(3)
    image_url = base_url + image_number + ".jpg"
    image_path = "images/" + image_number + ".jpg"

    # Download the image using wget
    os.system(f"wget -O {image_path} {image_url}")

print("Download complete.")
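The same download can be done without shelling out to wget; an alternative sketch using only the Python standard library:

```
# Alternative sketch: fetch the 24 benchmark images with urllib instead of wget.
import os
import urllib.request

os.makedirs("images", exist_ok=True)
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"

for i in range(1, 25):
    name = f"{i:03d}.jpg"
    urllib.request.urlretrieve(base_url + name, os.path.join("images", name))
    print("downloaded", name)
```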
@ -0,0 +1,33 @@
## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 25 --parallel 8
python3 bench_sglang.py --num-questions 16 --parallel 1
```


### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --backend vllm --num-questions 25
```


### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

### Benchmark lmql

```
python3 bench_other.py --backend lmql --num-questions 25 --parallel 1
```
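The benchmark scripts treat each decoded line of `articles.jsonl` as the article text itself, so the simplest input appears to be one JSON-encoded string per line (this format is inferred from how the scripts consume it, not documented here); a sketch with placeholder articles:

```
# Sketch: build a small articles.jsonl where each line decodes to one article
# string, matching how bench_sglang.py / bench_other.py read it.
import json

articles = [
    "This essay argues that ...",  # placeholder article text
    "A second essay on ...",
]
with open("articles.jsonl", "w") as fout:
    for article in articles:
        fout.write(json.dumps(article) + "\n")
```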
@ -0,0 +1,151 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
||||||
|
|
||||||
|
dimension_prompts = [
|
||||||
|
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
||||||
|
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
||||||
|
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
||||||
|
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
||||||
|
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
||||||
|
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def multi_dimension_judge(article, generate):
|
||||||
|
s = system_prompt
|
||||||
|
s += "\n```\n" + article + "\n```\n\n"
|
||||||
|
|
||||||
|
judges = []
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
comp = generate(
|
||||||
|
s
|
||||||
|
+ "USER: Please judge the quality based on the following metric. "
|
||||||
|
+ dimension_prompts[i]
|
||||||
|
+ " Please provide a single-paragraph judgement. "
|
||||||
|
+ "Focus on the provided metric and do not say other things. "
|
||||||
|
'End your judgement paragraph with the word "END"\nJUDGE:',
|
||||||
|
max_tokens=256,
|
||||||
|
stop="END",
|
||||||
|
)
|
||||||
|
judges.append(comp)
|
||||||
|
|
||||||
|
s += "I will judge the quality based on the following metrics.\n"
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
|
||||||
|
|
||||||
|
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
||||||
|
s += generate(s, max_tokens=2, stop=None)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
async def multi_dimension_judge_async(article, generate):
|
||||||
|
s = system_prompt
|
||||||
|
s += "\n```\n" + article + "\n```\n\n"
|
||||||
|
|
||||||
|
judges = []
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
comp = await generate(
|
||||||
|
s
|
||||||
|
+ "USER: Please judge the quality based on the following metric. "
|
||||||
|
+ dimension_prompts[i]
|
||||||
|
+ " Please provide a single-paragraph judgement. "
|
||||||
|
+ "Focus on the provided metric and do not say other things. "
|
||||||
|
'End your judgement paragraph with the word "END"\nJUDGE:',
|
||||||
|
max_tokens=256,
|
||||||
|
stop="END",
|
||||||
|
)
|
||||||
|
judges.append(comp)
|
||||||
|
|
||||||
|
s += "I will judge the quality based on the following metrics.\n"
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
|
||||||
|
|
||||||
|
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
||||||
|
s += await generate(s, max_tokens=2, stop=None)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
|
states = [None] * len(lines)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
call_generate = partial(get_call_generate(args), temperature=0)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
if args.backend != "lmql":
|
||||||
|
|
||||||
|
def get_one_answer(i):
|
||||||
|
states[i] = multi_dimension_judge(lines[i], call_generate)
|
||||||
|
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(lines))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(lines)))),
|
||||||
|
total=len(lines),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def get_one_answer_async(i):
|
||||||
|
states[i] = await multi_dimension_judge_async(lines[i], call_generate)
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
for i in range(0, len(lines), args.parallel):
|
||||||
|
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
for bt in tqdm(batches):
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
||||||
|
)
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "llm_judge",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="articles.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=20)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
||||||
|
|
||||||
|
dimension_prompts = [
|
||||||
|
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
||||||
|
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
||||||
|
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
||||||
|
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
||||||
|
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
||||||
|
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def multi_dimension_judge(s, article):
|
||||||
|
s += system_prompt
|
||||||
|
s += "\n```\n" + article + "\n```\n\n"
|
||||||
|
|
||||||
|
forks = s.fork(len(dimension_prompts))
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
forks[i] += (
|
||||||
|
"USER: Please judge the quality based on the following metric. "
|
||||||
|
+ dimension_prompts[i]
|
||||||
|
+ " Please provide a single-paragraph judgement. "
|
||||||
|
+ "Focus on the provided metric and do not say other things. "
|
||||||
|
'End your judgement paragraph with the word "END"\nJUDGE:'
|
||||||
|
)
|
||||||
|
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
|
||||||
|
forks.join()
|
||||||
|
|
||||||
|
s += "I will judge the quality based on the following metrics.\n"
|
||||||
|
for i in range(len(dimension_prompts)):
|
||||||
|
s += (
|
||||||
|
dimension_prompts[i].split(":")[0]
|
||||||
|
+ ": "
|
||||||
|
+ forks[i]["judgement"].strip()
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
||||||
|
s += sgl.gen("score", max_tokens=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
|
arguments = [{"article": l} for l in lines]
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = multi_dimension_judge.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "llm_judge",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="articles.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=20)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
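Neither judge script aggregates the generated scores. A hedged post-processing sketch, assuming each returned state exposes the "score" variable by name the way "answer" is accessed in the other benchmarks:

```
# Sketch: average the 1-10 scores generated above; tolerate non-numeric output.
import re

def average_score(states):
    scores = []
    for state in states:
        m = re.search(r"\d+", state["score"])  # "score" is the sgl.gen name above
        if m:
            scores.append(int(m.group()))
    return sum(scores) / len(scores) if scores else None
```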
@ -0,0 +1,33 @@
## Run benchmark

### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```

```
python3 bench_sglang.py --num-questions 5 --parallel 1
```


### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
```

```
python3 bench_other.py --backend vllm --num-questions 5
```


### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
```


### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
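`build_dataset.py` appends one `{"document": ...}` record per city to `questions.jsonl`; a quick sanity-check sketch for the generated file:

```
# Sketch: confirm questions.jsonl has the {"document": ...} records the
# benchmark scripts expect, and show how long each document is.
import json

with open("questions.jsonl") as fin:
    docs = [json.loads(line)["document"] for line in fin]
print(len(docs), "documents")
print([len(d) for d in docs])  # character lengths after truncation
```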
@ -0,0 +1,89 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
def json_decode(document, generate):
|
||||||
|
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": "'
|
||||||
|
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
||||||
|
s += ' "country": "'
|
||||||
|
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
||||||
|
s += ' "air port code": "'
|
||||||
|
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
||||||
|
s += ' "top 3 landmarks": "'
|
||||||
|
s += generate(s, max_tokens=24, stop='"') + '",\n'
|
||||||
|
s += "}\n"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)
|
||||||
|
arguments = []
|
||||||
|
for i in range(len(lines[: args.num_questions])):
|
||||||
|
arguments.append(
|
||||||
|
{
|
||||||
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
call_generate = partial(get_call_generate(args), temperature=0)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
def get_one_answer(i):
|
||||||
|
states[i] = json_decode(generate=call_generate, **arguments[i])
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(arguments))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(arguments)))),
|
||||||
|
total=len(arguments),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "long_json_decode",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=100)
|
||||||
|
args = add_common_other_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def json_decode(s, document):
|
||||||
|
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
|
||||||
|
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
|
||||||
|
s += (
|
||||||
|
' "air port code": "'
|
||||||
|
+ sgl.gen("air port code", max_tokens=8, stop='"')
|
||||||
|
+ '",\n'
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
' "top 3 landmarks": "'
|
||||||
|
+ sgl.gen("landmarks", max_tokens=24, stop='"')
|
||||||
|
+ '",\n'
|
||||||
|
)
|
||||||
|
s += "}\n"
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
lines = read_jsonl(args.data_path)
|
||||||
|
arguments = []
|
||||||
|
for i in range(len(lines[: args.num_questions])):
|
||||||
|
arguments.append(
|
||||||
|
{
|
||||||
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select backend
|
||||||
|
backend = select_sglang_backend(args)
|
||||||
|
sgl.set_default_backend(backend)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = json_decode.run_batch(
|
||||||
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "long_json_decode",
|
||||||
|
"backend": args.backend,
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=10)
|
||||||
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
|
main(args)
|
||||||
|
|
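Note that the string assembled above ends with a trailing comma before the closing brace, so it is not strictly valid JSON. If well-formed output is needed, a sketch that rebuilds it from the named generations (assuming the state exposes them by their sgl.gen names, as in the other benchmarks):

```
# Sketch: collect the individual generations into a well-formed JSON object.
import json

def to_json(state):
    return json.dumps(
        {
            "name": state["name"].strip(),
            "country": state["country"].strip(),
            "air port code": state["air port code"].strip(),
            "top 3 landmarks": state["landmarks"].strip(),
        },
        indent=2,
    )
```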
@ -0,0 +1,27 @@
import json

import transformers
import wikipedia

name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
city_names = ["los angeles", "london", "tokyo", "beijing", "singapore"]


for city_name in city_names:
    content = str(wikipedia.page(city_name).content)
    content = content.replace("\n\n", "\n")

    tokens = t.encode(content)

    truncate_len = int((10000 / len(tokens)) * len(content))
    truncate_content = content[:truncate_len]
    truncate_tokens = t.encode(truncate_content)

    # Count tokens
    print(
        f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
    )

    with open("questions.jsonl", "a") as fout:
        fout.write(json.dumps({"document": truncate_content}) + "\n")
@ -0,0 +1,68 @@
import argparse
import os

NUM_LORAS = 4
LORA_PATH = {
    "base": "meta-llama/Llama-2-7b-hf",
    "lora": "winddude/wizardLM-LlaMA-LoRA-7B",
}


def launch_server(args):
    base_path = LORA_PATH["base"]
    lora_path = LORA_PATH["lora"]

    if args.base_only:
        cmd = f"python3 -m sglang.launch_server --model {base_path} "
    else:
        cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
        for i in range(NUM_LORAS):
            lora_name = f"lora{i}"
            cmd += f"{lora_name}={lora_path} "
    cmd += "--disable-radix --disable-cuda-graph "
    cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
    cmd += f"--max-running-requests {args.max_running_requests} "
    cmd += f"--lora-backend {args.lora_backend} "
    cmd += f"--tp-size {args.tp_size} "
    if args.disable_custom_all_reduce:
        cmd += "--disable-custom-all-reduce"
    print(cmd)
    os.system(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument(
        "--max-loras-per-batch",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--max-running-requests",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--lora-backend",
        type=str,
        default="triton",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size for distributed inference",
    )
    # disable_custom_all_reduce
    parser.add_argument(
        "--disable-custom-all-reduce",
        action="store_true",
        help="Disable custom all reduce when device does not support p2p communication",
    )
    args = parser.parse_args()

    launch_server(args)
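Once the server is up, a single adapter-routed request can be smoke-tested before launching the full benchmark. A minimal sketch using requests instead of the benchmark's aiohttp client (the prompt text is arbitrary; "lora0" is the first adapter name registered above):

```
# Sketch: send one /generate request routed to the first LoRA adapter
# registered by launch_server.py (named "lora0").
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "What is the capital of France?",
        "sampling_params": {"max_new_tokens": 32},
        "lora_path": "lora0",
    },
)
print(resp.json()["text"])
```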
@ -0,0 +1,475 @@
|
||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import resource
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from launch_server import LORA_PATH, NUM_LORAS
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from sglang.bench_serving import (
|
||||||
|
AIOHTTP_TIMEOUT,
|
||||||
|
RequestFuncInput,
|
||||||
|
RequestFuncOutput,
|
||||||
|
calculate_metrics,
|
||||||
|
get_request,
|
||||||
|
get_tokenizer,
|
||||||
|
remove_prefix,
|
||||||
|
sample_random_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
global args
|
||||||
|
|
||||||
|
|
||||||
|
# set ignore_eos True by default
|
||||||
|
async def async_request_openai_completions(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
# assert api_url.endswith(
|
||||||
|
# "completions"
|
||||||
|
# ), "OpenAI Completions API URL must end with 'completions'."
|
||||||
|
|
||||||
|
prompt = request_func_input.prompt
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
# payload = {
|
||||||
|
# "model": request_func_input.model,
|
||||||
|
# "prompt": prompt,
|
||||||
|
# "temperature": 0.0,
|
||||||
|
# "best_of": 1,
|
||||||
|
# "max_tokens": request_func_input.output_len,
|
||||||
|
# "stream": not args.disable_stream,
|
||||||
|
# "ignore_eos": not args.disable_ignore_eos,
|
||||||
|
# **request_func_input.extra_request_body,
|
||||||
|
# }
|
||||||
|
# headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
if args.base_only:
|
||||||
|
payload = {
|
||||||
|
"text": prompt,
|
||||||
|
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
payload = {
|
||||||
|
"text": prompt,
|
||||||
|
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
||||||
|
"lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
|
||||||
|
}
|
||||||
|
headers = {"Authorization": ""}
|
||||||
|
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
generated_text = ""
|
||||||
|
ttft = 0.0
|
||||||
|
st = time.perf_counter()
|
||||||
|
most_recent_timestamp = st
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for chunk_bytes in response.content:
|
||||||
|
chunk_bytes = chunk_bytes.strip()
|
||||||
|
if not chunk_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||||
|
latency = time.perf_counter() - st
|
||||||
|
if chunk == "[DONE]":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
|
||||||
|
# NOTE: Some completion API might have a last
|
||||||
|
# usage summary response without a token so we
|
||||||
|
# want to check a token was generated
|
||||||
|
if data["text"]:
|
||||||
|
# if data["choices"][0]["text"]:
|
||||||
|
timestamp = time.perf_counter()
|
||||||
|
# First token
|
||||||
|
if ttft == 0.0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
|
||||||
|
# Decoding phase
|
||||||
|
else:
|
||||||
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
# generated_text += data["choices"][0]["text"]
|
||||||
|
generated_text += data["text"]
|
||||||
|
|
||||||
|
output.generated_text = generated_text
|
||||||
|
output.success = True
|
||||||
|
output.latency = latency
|
||||||
|
output.output_len = request_func_input.output_len
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception:
|
||||||
|
output.success = False
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
ASYNC_REQUEST_FUNCS = {
|
||||||
|
"sglang": async_request_openai_completions,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def benchmark(
|
||||||
|
backend: str,
|
||||||
|
api_url: str,
|
||||||
|
model_id: str,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
input_requests: List[Tuple[str, int, int]],
|
||||||
|
request_rate: float,
|
||||||
|
disable_tqdm: bool,
|
||||||
|
extra_request_body: Dict[str, Any],
|
||||||
|
):
|
||||||
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown backend: {backend}")
|
||||||
|
|
||||||
|
print("Starting initial single prompt test run...")
|
||||||
|
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
||||||
|
test_input = RequestFuncInput(
|
||||||
|
model=model_id,
|
||||||
|
prompt=test_prompt,
|
||||||
|
api_url=api_url,
|
||||||
|
prompt_len=test_prompt_len,
|
||||||
|
output_len=test_output_len,
|
||||||
|
lora_name="dummy", # the lora_name argument will not be used
|
||||||
|
extra_request_body=extra_request_body,
|
||||||
|
)
|
||||||
|
test_output = await request_func(request_func_input=test_input)
|
||||||
|
if not test_output.success:
|
||||||
|
raise ValueError(
|
||||||
|
"Initial test run failed - Please make sure benchmark arguments "
|
||||||
|
f"are correctly specified. Error: {test_output.error}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
|
benchmark_start_time = time.perf_counter()
|
||||||
|
tasks: List[asyncio.Task] = []
|
||||||
|
async for request in get_request(input_requests, request_rate):
|
||||||
|
prompt, prompt_len, output_len = request
|
||||||
|
request_func_input = RequestFuncInput(
|
||||||
|
model=model_id,
|
||||||
|
prompt=prompt,
|
||||||
|
api_url=api_url,
|
||||||
|
prompt_len=prompt_len,
|
||||||
|
output_len=output_len,
|
||||||
|
lora_name="dummy",
|
||||||
|
extra_request_body=extra_request_body,
|
||||||
|
)
|
||||||
|
tasks.append(
|
||||||
|
asyncio.create_task(
|
||||||
|
request_func(request_func_input=request_func_input, pbar=pbar)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if pbar is not None:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||||
|
|
||||||
|
metrics, output_lens = calculate_metrics(
|
||||||
|
input_requests=input_requests,
|
||||||
|
outputs=outputs,
|
||||||
|
dur_s=benchmark_duration,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
||||||
|
print("{:<40} {:<10}".format("Backend:", backend))
|
||||||
|
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
|
||||||
|
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||||
|
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
||||||
|
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||||
|
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10}".format(
|
||||||
|
"Total generated tokens (retokenized):", metrics.total_output_retokenized
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format(
|
||||||
|
"Request throughput (req/s):", metrics.request_throughput
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format(
|
||||||
|
"Input token throughput (tok/s):", metrics.input_throughput
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format(
|
||||||
|
"Output token throughput (tok/s):", metrics.output_throughput
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
|
||||||
|
)
|
||||||
|
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format(
|
||||||
|
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
|
||||||
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||||
|
print(
|
||||||
|
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
||||||
|
)
|
||||||
|
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||||
|
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
||||||
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
if (
|
||||||
|
metrics.median_ttft_ms is not None
|
||||||
|
and metrics.mean_itl_ms is not None
|
||||||
|
and metrics.output_throughput is not None
|
||||||
|
):
|
||||||
|
result = {
|
||||||
|
"backend": args.backend,
|
||||||
|
"request_rate": request_rate,
|
||||||
|
"total_input_tokens": metrics.total_input,
|
||||||
|
"total_output_tokens": metrics.total_output,
|
||||||
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
||||||
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
||||||
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
||||||
|
"median_ttft_ms": metrics.median_ttft_ms,
|
||||||
|
"median_itl_ms": metrics.median_itl_ms,
|
||||||
|
"output_throughput": metrics.output_throughput,
|
||||||
|
"random_input_len": args.random_input_len,
|
||||||
|
"random_output_len": args.random_output_len,
|
||||||
|
"random_range_ratio": args.random_range_ratio,
|
||||||
|
"duration": benchmark_duration,
|
||||||
|
"completed": metrics.completed,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print(f"Error running benchmark for request rate: {request_rate}")
|
||||||
|
print("-" * 30)
|
||||||
|
|
||||||
|
# Determine output file name
|
||||||
|
if args.output_file:
|
||||||
|
output_file_name = args.output_file
|
||||||
|
else:
|
||||||
|
now = datetime.now().strftime("%m%d")
|
||||||
|
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
||||||
|
|
||||||
|
# Append results to a JSONL file
|
||||||
|
with open(output_file_name, "a") as file:
|
||||||
|
file.write(json.dumps(result) + "\n")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"duration": benchmark_duration,
|
||||||
|
"completed": metrics.completed,
|
||||||
|
"total_input_tokens": metrics.total_input,
|
||||||
|
"total_output_tokens": metrics.total_output,
|
||||||
|
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
||||||
|
"request_throughput": metrics.request_throughput,
|
||||||
|
"input_throughput": metrics.input_throughput,
|
||||||
|
"output_throughput": metrics.output_throughput,
|
||||||
|
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||||
|
"median_ttft_ms": metrics.median_ttft_ms,
|
||||||
|
"std_ttft_ms": metrics.std_ttft_ms,
|
||||||
|
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||||
|
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||||
|
"median_tpot_ms": metrics.median_tpot_ms,
|
||||||
|
"std_tpot_ms": metrics.std_tpot_ms,
|
||||||
|
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||||
|
"mean_itl_ms": metrics.mean_itl_ms,
|
||||||
|
"median_itl_ms": metrics.median_itl_ms,
|
||||||
|
"std_itl_ms": metrics.std_itl_ms,
|
||||||
|
"p99_itl_ms": metrics.p99_itl_ms,
|
||||||
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
|
"output_lens": output_lens,
|
||||||
|
"ttfts": [output.ttft for output in outputs],
|
||||||
|
"itls": [output.itl for output in outputs],
|
||||||
|
"generated_texts": [output.generated_text for output in outputs],
|
||||||
|
"errors": [output.error for output in outputs],
|
||||||
|
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
||||||
|
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|


def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
        }.get(args.backend, 30000)

    # api_url = (
    #     f"{args.base_url}/v1/completions"
    #     if args.base_url
    #     else f"http://{args.host}:{args.port}/v1/completions"
    # )
    api_url = (
        f"{args.base_url}/generate"
        if args.base_url
        else f"http://{args.host}:{args.port}/generate"
    )

    print(f"{args}\n")

    # Prepare the tokenizer and sample random input requests
    backend = args.backend
    model_id = args.model = LORA_PATH["base"]
    tokenizer_id = args.model

    tokenizer = get_tokenizer(tokenizer_id)

    input_requests = sample_random_requests(
        input_len=args.random_input_len,
        output_len=args.random_output_len,
        num_prompts=args.num_prompts,
        range_ratio=args.random_range_ratio,
        tokenizer=tokenizer,
        dataset_path="",
    )

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            disable_tqdm=False,
            extra_request_body={},
        )
    )


def set_ulimit(target_soft_limit=65535):
    # Raise the open-file soft limit so many concurrent connections can be in flight.
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online LoRA serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="The LLM inference engine backend to benchmark.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is chosen according to the selected LLM inference engine.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=50,
        help="Number of prompts to process. Default is 50.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for the random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for the random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for the random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, all requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
    )
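    # Note on the Poisson arrival pattern mentioned in the help text above: it is
    # typically synthesized by sampling exponential inter-arrival gaps, e.g. (a
    # sketch, not necessarily the exact code used by this benchmark's request
    # generator):
    #
    #     interval = np.random.exponential(1.0 / request_rate)
    #     await asyncio.sleep(interval)  # wait before dispatching the next request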
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    args = parser.parse_args()
    run_benchmark(args)

@ -0,0 +1,59 @@
## Download data
```
bash download_data.sh
```

## Run benchmark

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --nsub 10
```

```
# OpenAI models
python3 bench_sglang.py --backend gpt-3.5-turbo --parallel 8
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --nsub 10 --backend vllm
```

### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000

# V100
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000
```

```
python3 bench_other.py --nsub 10 --backend lightllm
```

### Benchmark guidance
```
python3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```

```
python3 bench_other.py --nsub 10 --backend lmql --parallel 2
```

@ -0,0 +1,173 @@
import argparse
import asyncio
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate

choices = ["A", "B", "C", "D"]

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")


def format_subject(subject):
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt
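# For illustration, gen_prompt(dev_df, "abstract_algebra", k) produces a few-shot
# prompt shaped roughly like this (a sketch; the actual questions and answers come
# from the MMLU dev split):
#
#   The following are multiple choice questions (with answers) about abstract algebra.
#
#   <question text>
#   A. <option>
#   B. <option>
#   C. <option>
#   D. <option>
#   Answer: B
#
#   ... (k such examples) ...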


def evaluate(args, subject, dev_df, test_df, call_generate):
    prompts = []
    labels = []

    # Construct prompts
    k = args.ntrain
    train_prompt = gen_prompt(dev_df, subject, k)
    while len(tokenizer.encode(train_prompt)) > 1536:
        k -= 1
        train_prompt = gen_prompt(dev_df, subject, k)

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = train_prompt + prompt_end
        prompts.append(prompt)

        label = test_df.iloc[i, test_df.shape[1] - 1]
        labels.append(label)

    preds = [None] * len(prompts)
    max_tokens = 1

    # Run requests
    if args.backend != "lmql":
        # Use thread pool
        def get_one_answer(i):
            pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)
            preds[i] = pred.strip()[0]

        tic = time.time()
        if args.parallel == 1:
            for i in range(len(prompts)):
                get_one_answer(i)
        else:
            with ThreadPoolExecutor(args.parallel) as executor:
                executor.map(get_one_answer, list(range(len(prompts))))
    else:
        # Use asyncio
        async def batched_call(batch_size):
            for i in range(0, len(prompts), batch_size):
                tasks = []
                for p in prompts[i : i + batch_size]:
                    tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))
                rets = await asyncio.gather(*tasks)
                for j in range(len(rets)):
                    preds[i + j] = rets[j].strip()[0]

        tic = time.time()
        asyncio.run(batched_call(batch_size=args.parallel))
    latency = time.time() - tic

    # Compute accuracy
    cors = [pred == label for pred, label in zip(preds, labels)]
    acc = np.mean(cors)
    cors = np.array(cors)

    print(
        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
            acc, latency, len(prompts), subject
        )
    )

    return cors, acc, latency


def main(args):
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    all_latencies = []
    num_requests = 0

    # Select backend
    call_generate = get_call_generate(args)

    for subject in tqdm(subjects[: args.nsub]):
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate)
        all_cors.append(cors)
        all_latencies.append(latency)
        num_requests += len(test_df)

    total_latency = np.sum(all_latencies)
    print("Total latency: {:.3f}".format(total_latency))

    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(total_latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": num_requests,
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", type=int, default=5)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_other_args_and_parse(parser)
    main(args)

@ -0,0 +1,173 @@
import argparse
import json
import os
import time

import numpy as np
import pandas as pd
import tiktoken

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

choices = ["A", "B", "C", "D"]

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")


def format_subject(subject):
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s


def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt


def main(args):
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    # Build prompts
    arguments = []
    labels = []
    num_questions = []

    for subject in subjects[: args.nsub]:
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )
        num_questions.append(test_df.shape[0])

        k = args.ntrain
        few_shot_examples = gen_prompt(dev_df, subject, k)
        while len(tokenizer.encode(few_shot_examples)) > 1536:
            k -= 1
            few_shot_examples = gen_prompt(dev_df, subject, k)

        for i in range(test_df.shape[0]):
            prompt_end = format_example(test_df, i, include_answer=False)

            arguments.append(
                {
                    "examples": few_shot_examples,
                    "question": prompt_end,
                }
            )

            label = test_df.iloc[i, test_df.shape[1] - 1]
            labels.append(label)

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    if args.backend.startswith("gpt-"):

        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += sgl.user(examples + question)
            s += sgl.assistant(sgl.gen("answer"))

    else:

        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += examples + question + sgl.gen("answer")

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    # Run
    tic = time.time()
    states = few_shot_mmlu.run_batch(
        arguments,
        temperature=0,
        max_new_tokens=1,
        backend=backend,
        num_threads=args.parallel,
        progress_bar=True,
    )
    preds = [
        s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
    ]
    latency = time.time() - tic

    # Compute accuracy
    cors = [pred == label for pred, label in zip(preds, labels)]

    pt = 0
    for subject, num_qs in zip(subjects[: args.nsub], num_questions):
        print(
            f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}"
        )
        pt += num_qs
    assert pt == len(cors)
    weighted_acc = np.mean(cors)

    # Print results
    print("Total latency: {:.3f}".format(latency))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": len(arguments),
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_sglang_args_and_parse(parser)
    main(args)

@ -0,0 +1,2 @@
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
tar xf data.tar

@ -0,0 +1,30 @@
## Run evaluation

### Evaluate sglang

Host the VLM:

```
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
```

Benchmark:

```
python benchmark/mmmu/bench_sglang.py --port 30000
```

It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the server launch command above.
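For instance, the full launch command would then be (the value `0.6` is just the example mentioned above; tune it to your GPU memory):

```
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000 --mem-fraction-static 0.6
```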

### Evaluate hf

```
python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct
```

Some popular model results:

1. Qwen/Qwen2-VL-2B-Instruct: 0.241
2. Qwen/Qwen2-VL-7B-Instruct: 0.255
3. Qwen/Qwen2.5-VL-3B-Instruct: 0.245
4. Qwen/Qwen2.5-VL-7B-Instruct: 0.242

@ -0,0 +1,105 @@
import argparse

import torch
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, GenerationConfig


@torch.no_grad()
def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)
    try:
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            args.model_path,
            torch_dtype="auto",
            trust_remote_code=True,
        )
    except Exception as first_exception:
        try:
            model = AutoModel.from_pretrained(
                args.model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                init_tts=False,
            )
        except Exception as second_exception:
            raise RuntimeError(
                f"Failed to load model: First attempt failed with {first_exception}, "
                f"second attempt failed with {second_exception}"
            ) from second_exception

    model = model.eval().cuda()

    processor = AutoProcessor.from_pretrained(
        args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )

    samples = prepare_samples(eval_args)
    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)
    generation_config = GenerationConfig(
        max_new_tokens=sampling_params["max_new_tokens"],
        do_sample=False,
    )

    answer_dict = {}
    for sample in tqdm(samples):
        prompt = sample["final_input_prompt"]
        image = sample["image"]
        # The prompt embeds the image as an "<image N>" placeholder; keep the text
        # before and after it.
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        assert image is not None
        contents = []
        if prefix:
            contents += [{"type": "text", "text": prefix}]
        contents += [
            {
                "type": "image",
                "image": sample["image_path"],
            }
        ]
        if suffix:
            contents += [{"type": "text", "text": suffix}]
        messages = [{"role": "user", "content": contents}]
        model_inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, generation_config=generation_config)
        generation = generation[0][input_len:]
        response = processor.decode(generation, skip_special_tokens=True)
        print(f"response: {response}")
        process_result(response, sample, answer_dict, out_samples)

    args.output_path = f"{args.model_path}_val_hf.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        required=True,
    )
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)

@ -0,0 +1,93 @@
"""
Benchmark the sglang-hosted VLM on the MMMU benchmark.

Usage:
    Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000

    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000

The eval output will be logged.
"""

import argparse
import time

import openai
from data_utils import save_json
from eval_utils import (
    EvalArgs,
    eval_result,
    get_sampling_params,
    prepare_samples,
    process_result,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_sglang_args_and_parse


def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)

    out_samples = dict()

    sampling_params = get_sampling_params(eval_args)

    samples = prepare_samples(eval_args)

    answer_dict = {}

    # Use the OpenAI-compatible server endpoint, since SglImage doesn't support image data.
    client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")

    start = time.time()
    for i, sample in enumerate(tqdm(samples)):
        prompt = sample["final_input_prompt"]
        prefix = prompt.split("<")[0]
        suffix = prompt.split(">")[1]
        image = sample["image"]
        assert image is not None
        image_path = sample["image_path"]
        # TODO: batch
        response = client.chat.completions.create(
            model="default",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prefix,
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": image_path},
                        },
                        {
                            "type": "text",
                            "text": suffix,
                        },
                    ],
                }
            ],
            temperature=0,
            max_completion_tokens=sampling_params["max_new_tokens"],
            max_tokens=sampling_params["max_new_tokens"],
        )
        response = response.choices[0].message.content
        process_result(response, sample, answer_dict, out_samples)

    print(f"Benchmark time: {time.time() - start}")

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = add_common_sglang_args_and_parse(parser)
    EvalArgs.add_cli_args(parser)
    args = parser.parse_args()

    eval_mmmu(args)

@ -0,0 +1,221 @@
"""Utils for data load, save, and process (e.g., prompt construction)"""

import json
import os
import re

import yaml

DOMAIN_CAT2SUB_CAT = {
    "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
    "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
    "Science": [
        "Biology",
        "Chemistry",
        "Geography",
        "Math",
        "Physics",
    ],
    "Health and Medicine": [
        "Basic_Medical_Science",
        "Clinical_Medicine",
        "Diagnostics_and_Laboratory_Medicine",
        "Pharmacy",
        "Public_Health",
    ],
    "Humanities and Social Science": [
        "History",
        "Literature",
        "Sociology",
        "Psychology",
    ],
    "Tech and Engineering": [
        "Agriculture",
        "Architecture_and_Engineering",
        "Computer_Science",
        "Electronics",
        "Energy_and_Power",
        "Materials",
        "Mechanical_Engineering",
    ],
}


CAT_SHORT2LONG = {
    "acc": "Accounting",
    "agri": "Agriculture",
    "arch": "Architecture_and_Engineering",
    "art": "Art",
    "art_theory": "Art_Theory",
    "bas_med": "Basic_Medical_Science",
    "bio": "Biology",
    "chem": "Chemistry",
    "cli_med": "Clinical_Medicine",
    "cs": "Computer_Science",
    "design": "Design",
    "diag_med": "Diagnostics_and_Laboratory_Medicine",
    "econ": "Economics",
    "elec": "Electronics",
    "ep": "Energy_and_Power",
    "fin": "Finance",
    "geo": "Geography",
    "his": "History",
    "liter": "Literature",
    "manage": "Manage",
    "mark": "Marketing",
    "mate": "Materials",
    "math": "Math",
    "mech": "Mechanical_Engineering",
    "music": "Music",
    "phar": "Pharmacy",
    "phys": "Physics",
    "psy": "Psychology",
    "pub_health": "Public_Health",
    "socio": "Sociology",
}


# DATA SAVING
def save_json(filename, ds):
    with open(filename, "w") as f:
        json.dump(ds, f, indent=4)


def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return index2ans and all_choices.
    """

    start_chr = "A"
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices
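# For example:
#   get_multi_choice_info(["cat", "dog"]) -> ({"A": "cat", "B": "dog"}, ["A", "B"])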


def load_yaml(file_path):
    yaml_dict = None  # avoid NameError if parsing fails below
    with open(file_path, "r") as stream:
        try:
            yaml_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    return yaml_dict


def parse_img_path(text):
    matches = re.findall("<img='(.*?)'>", text)
    return matches


def process_single_sample(data):
    question = data["question"]
    o_imgs_paths = []
    for option in data["options"]:
        current_o_imgs_paths = parse_img_path(option)
        for img_path in current_o_imgs_paths:
            o_imgs_paths.append(img_path)

    if len(o_imgs_paths) > 1:  # multiple images in options, used for random selection
        return {
            "id": data["id"],
            "question": question,
            "options": data["options"],
            "answer": data["answer"],
            "image": None,
            "question_type": data["question_type"],
        }
    else:
        return {
            "id": data["id"],
            "question": question,
            "options": data["options"],
            "answer": data["answer"],
            "image": data["image_1"],
            "question_type": data["question_type"],
        }


# DATA SAVING
def save_json(filename, ds):
    print(f"answers saved to: {filename}")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w") as f:
        json.dump(ds, f, indent=4)


def save_jsonl(filename, data):
    """
    Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.

    Args:
        filename (str): The path to the file where the data should be saved.
        data (dict): The dictionary containing the data to save, where key is the image path and value is the caption.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for img_path, caption in data.items():
            # Extract the base filename without the extension
            base_filename = os.path.basename(img_path)
            # Create a JSON object with the filename as the key and caption as the value
            json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
            # Write the JSON object to the file, one per line
            f.write(json_record + "\n")


def save_args(args, path_dir):
    argsDict = args.__dict__
    with open(path_dir + "setting.txt", "w") as f:
        f.writelines("------------------ start ------------------" + "\n")
        for eachArg, value in argsDict.items():
            f.writelines(eachArg + " : " + str(value) + "\n")
        f.writelines("------------------- end -------------------")


# DATA PROCESSING
def construct_prompt(sample, config):
    question = sample["question"]
    options = eval(sample["options"])
    example = ""
    if sample["question_type"] == "multiple-choice":
        start_chr = "A"
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        empty_prompt_sample_structure = config["multi_choice_example_format"]
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {}
        res_dict["index2ans"] = index2ans
        res_dict["correct_choice"] = sample["answer"]
        res_dict["all_choices"] = prediction_range
        res_dict["empty_prompt"] = empty_prompt
        if config["task_instructions"]:
            res_dict["final_input_prompt"] = (
                config["task_instructions"].strip() + "\n\n" + empty_prompt
            )
        else:
            res_dict["final_input_prompt"] = empty_prompt

        res_dict["gt_content"] = options[ord(sample["answer"].upper()) - ord("A")]
    else:
        empty_prompt_sample_structure = config["short_ans_example_format"]
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {}
        res_dict["empty_prompt"] = empty_prompt
        if config["task_instructions"]:
            res_dict["final_input_prompt"] = (
                config["task_instructions"].strip() + "\n\n" + empty_prompt
            )
        else:
            res_dict["final_input_prompt"] = empty_prompt
        res_dict["gt_content"] = sample["answer"]

    res_dict.update(sample)
    return res_dict