
# Attention Backend

SGLang supports multiple attention backends, each with different pros and cons. You can test them and pick the one that best fits your needs.

## Support matrix for different attention backends

| Backend      | Page Size > 1 | Spec Decoding | MLA | Sliding Window | MultiModal |
|--------------|:-------------:|:-------------:|:---:|:--------------:|:----------:|
| FlashInfer   | ✅            | ✅            | ✅  | ✅             | ✅         |
| FA3          | ✅            | ✅            | ✅  | ✅             | ✅         |
| Triton       | ❌            | ✅            | ✅  | ✅             | ❌         |
| Torch Native | ❌            | ❌            | ❌  | ❌             | ❌         |
| FlashMLA     | ✅            | ✅            | ✅  | ❌             | ❌         |
| TRTLLM MLA   | ✅            | ✅            | ✅  | ❌             | ❌         |
| Ascend       | ✅            | ❌            | ✅  | ❌             | ❌         |
| Wave         | ❌            | ❌            | ❌  | ❌             | ❌         |

Notes:

- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.

Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`, because a page size of 16 can be converted to a page size of 1 inside the kernel backend. The "✅" and "❌" symbols under "Page Size > 1" in the table above indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
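For example, any backend can be launched with 16-token pages; a backend marked "❌" will simply handle them as page size 1 internally. An illustrative command, reusing a model from the section below:

```bash
# 16-token pages; backends without native support treat this as page size 1.
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --attention-backend triton --page-size 16
```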

## User guide

Launch commands for different attention backends:

- FlashInfer (Default for non-Hopper machines, e.g., A100, A40)

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend flashinfer
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend flashinfer --trust-remote-code
  ```

- FlashAttention 3 (Default for Hopper machines, e.g., H100, H200, H20)

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend fa3
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --trust-remote-code --attention-backend fa3
  ```

- Triton

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend triton
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend triton --trust-remote-code
  ```

- Torch Native

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend torch_native
  ```

- FlashMLA

  ```bash
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
  ```

- TRTLLM MLA (Optimized for Blackwell architecture, e.g., B200)

  ```bash
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
  ```

- TRTLLM MLA with FP8 KV cache (higher concurrency, lower memory footprint)

  ```bash
  python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code
  ```

- Ascend

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
  ```

- Wave

  ```bash
  python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend wave
  ```
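Whichever backend you pick, you can smoke-test the launched server with a quick request to the native generation endpoint. This sketch assumes the default port `30000`:

```bash
# Minimal smoke test against a running sglang server (default port assumed).
curl -s http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}}'
```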

## Steps to add a new attention backend

To add a new attention backend, you can learn from the existing backends (`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`) and follow the steps below.

1. Run without CUDA graph. Implement the two forward functions plus the metadata initializer (see the skeleton after this list):
   - `forward_extend`
     - Used for prefill, prefill with KV cache, and target verification
     - Called once per layer
   - `forward_decode`
     - Used for normal decode and draft decode
     - Called once per layer
   - `init_forward_metadata`
     - Initializes the class and the common metadata shared by all layers
     - Calls the plan function for optimizations like `split_kv`
     - Called once per forward pass
2. Run with CUDA graph. It has two phases (capture and replay), and you need to implement three functions (also shown in the skeleton below):
   - `init_cuda_graph_state`
     - Called once during the lifetime of the backend
     - Creates all common shared buffers
   - `init_forward_metadata_capture_cuda_graph`
     - Called before capturing a CUDA graph
     - Similar to `init_forward_metadata`, but writes the metadata to pre-defined buffers
   - `init_forward_metadata_replay_cuda_graph`
     - Called before replaying a CUDA graph
     - This function is on the critical path and needs to be fast
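The following is a minimal, hypothetical skeleton that ties these functions together. The class name, constructor argument, and `ForwardMetadata` fields are illustrative assumptions, not SGLang's exact interface; copy the real base class and signatures from the existing backends referenced above.

```python
# Hypothetical backend skeleton; names and signatures are illustrative only
# (see triton_backend.py / flashattention_backend.py for the real interface).
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    # Per-forward metadata shared by all layers (illustrative fields).
    seq_lens: Optional[torch.Tensor] = None
    kv_indices: Optional[torch.Tensor] = None


class MyAttentionBackend:
    def __init__(self, model_runner):
        # One-time setup: stash whatever the kernels need (device, KV pool, ...).
        self.model_runner = model_runner
        self.forward_metadata: Optional[ForwardMetadata] = None
        self.graph_metadata: Optional[ForwardMetadata] = None

    # ---- Step 1: run without CUDA graph ----

    def init_forward_metadata(self, forward_batch):
        # Called once per forward pass, before any layer runs. This is the
        # place to call a plan function for optimizations such as split_kv.
        self.forward_metadata = ForwardMetadata(seq_lens=forward_batch.seq_lens)

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Called once per layer for prefill, prefill with KV cache, and
        # target verification. Write k/v into the KV cache, then attend.
        raise NotImplementedError("run the extend attention kernel here")

    def forward_decode(self, q, k, v, layer, forward_batch):
        # Called once per layer for normal decode and draft decode.
        raise NotImplementedError("run the decode attention kernel here")

    # ---- Step 2: run with CUDA graph (capture and replay) ----

    def init_cuda_graph_state(self, max_bs):
        # Called once during the backend's lifetime: allocate all shared
        # buffers at their maximum size so every captured graph can reuse them.
        self.graph_metadata = ForwardMetadata(
            seq_lens=torch.zeros(max_bs, dtype=torch.int32)
        )

    def init_forward_metadata_capture_cuda_graph(self, bs, forward_batch):
        # Called before capturing a CUDA graph. Same job as
        # init_forward_metadata, but writes into the pre-defined buffers.
        self.graph_metadata.seq_lens[:bs].copy_(forward_batch.seq_lens)
        self.forward_metadata = self.graph_metadata

    def init_forward_metadata_replay_cuda_graph(self, bs, forward_batch):
        # Called before every replay; this is on the critical path, so do
        # only cheap in-place updates of the captured buffers.
        self.graph_metadata.seq_lens[:bs].copy_(forward_batch.seq_lens)
```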