.. _kv-layout:

KV-Cache Layout in FlashInfer
=============================

Layout: NHD/HND
---------------

FlashInfer provides two layouts for the last 3 dimensions of the KV-Cache: ``NHD`` and ``HND``:

- ``NHD``: the last 3 dimensions are organized as ``(seq_len, num_heads, head_dim)``.
- ``HND``: the last 3 dimensions are organized as ``(num_heads, seq_len, head_dim)``.

The ``NHD`` layout is more natural because it is consistent with the output of
:math:`xW_k` and :math:`xW_v` without transposition. The ``HND`` layout is friendlier
to GPU implementations when the KV-Cache uses a low-precision data type (e.g. fp8).
In practice we do not observe a significant performance difference between the two layouts
for fp16 KV-Cache, so we prioritize the ``NHD`` layout for better readability. FlashInfer implements
attention kernels for both layouts and provides an option to select between them (``NHD``
by default).
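
As a concrete illustration (a minimal sketch with hypothetical sizes, independent of any
FlashInfer API), the two layouts differ only in the order of the ``seq_len`` and
``num_heads`` dimensions:

.. code:: python

    import torch

    num_heads, seq_len, head_dim = 32, 128, 128

    # NHD layout: (seq_len, num_heads, head_dim), matching the shape of x @ W_k
    k_nhd = torch.randn(seq_len, num_heads, head_dim).half()

    # HND layout: (num_heads, seq_len, head_dim), obtained by swapping the first two axes
    k_hnd = k_nhd.transpose(0, 1).contiguous()

    assert k_hnd.shape == (num_heads, seq_len, head_dim)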

.. _ragged-layout:

Ragged Tensor
-------------

We use a Ragged Tensor to store the variable-length Q/K/V tensors in FlashInfer for batch prefill self-attention:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
    :width: 400
    :align: center
    :alt: Data structure of Ragged KV-Cache.

In a Ragged Tensor, all requests' Q/K/V are packed into a single ``data`` tensor without padding.
We use an ``indptr`` array (``num_requests+1`` elements, the first element is always zero)
to store the variable sequence length of each request
(``indptr[i+1]-indptr[i]`` is the sequence length of request ``i``); the ``data`` tensor has
shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.

We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.

.. note::
    ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors.
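
For example, ``indptr`` is simply the prefix sum of the per-request sequence lengths
(a minimal sketch; the sequence lengths and sizes below are hypothetical):

.. code:: python

    import torch

    num_heads, head_dim = 32, 128
    seq_lens = [5, 3, 8]  # per-request sequence lengths

    # indptr has num_requests + 1 elements and starts at zero; note the int32 dtype
    indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(torch.tensor(seq_lens), dim=0)

    # packed data tensor in NHD layout: (total_tokens, num_heads, head_dim)
    data = torch.randn(int(indptr[-1]), num_heads, head_dim).half()

    # keys (or values) of request i
    i = 1
    request_i = data[indptr[i]:indptr[i + 1]]
    assert request_i.shape == (seq_lens[i], num_heads, head_dim)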

FlashInfer APIs
~~~~~~~~~~~~~~~

FlashInfer provides :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` to compute
the prefill attention between queries stored in a ragged tensor and keys/values stored in a ragged
KV-Cache.

.. _mask-layout:

Mask Layout (2D Ragged Tensor)
------------------------------

The aforementioned Ragged Tensor can be generalized to multiple "ragged" dimensions. For example,
the attention mask in FlashInfer is a 2D ragged tensor when the batch size is greater than 1:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/mask-layout.png
    :width: 800
    :align: center
    :alt: Data structure of Mask Layout.

When the number of requests is greater than 1, different requests might have different query and kv lengths.
To avoid padding, we use a 2D ragged tensor to store the attention mask. The input ``qo_indptr`` and
``kv_indptr`` arrays (both of length ``num_requests+1``) store the variable sequence lengths of each request:
``qo_indptr[i+1]-qo_indptr[i]`` is the query length of request ``i`` (``qo_len[i]``), and
``kv_indptr[i+1]-kv_indptr[i]`` is the kv length of request ``i`` (``kv_len[i]``).

The mask arrays of all requests are flattened (with query as the first dimension and kv as the last dimension)
and concatenated into a single 1D array: ``mask_data``. FlashInfer implicitly creates a ``mask_indptr`` array
that stores the start offset of each request's mask in the flattened mask array: ``mask_indptr[1:] = cumsum(qo_len * kv_len)``.

``mask_data`` has shape ``(mask_indptr[-1],)``, and we can use ``mask_data[mask_indptr[i]:mask_indptr[i+1]]`` to slice the flattened
mask of request ``i``.
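
The following sketch shows how ``mask_indptr`` relates to the per-request query and kv lengths
(the lengths below are hypothetical; only the relationships described above are assumed):

.. code:: python

    import torch

    qo_len = torch.tensor([2, 3], dtype=torch.int32)  # query length per request
    kv_len = torch.tensor([4, 5], dtype=torch.int32)  # kv length per request

    # start offset of each request's flattened mask: cumsum of qo_len * kv_len
    mask_indptr = torch.zeros(len(qo_len) + 1, dtype=torch.int32)
    mask_indptr[1:] = torch.cumsum(qo_len * kv_len, dim=0)

    # flattened boolean mask of all requests, shape (mask_indptr[-1],)
    mask_data = torch.ones(int(mask_indptr[-1]), dtype=torch.bool)

    # mask of request i, reshaped back to (qo_len[i], kv_len[i])
    i = 1
    mask_i = mask_data[mask_indptr[i]:mask_indptr[i + 1]].reshape(int(qo_len[i]), int(kv_len[i]))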

To save memory, we can further pack the flattened boolean mask array into a bit-packed array (1 bit per element, with 8 elements
packed together into one ``uint8``) using "little" bit-order (see `numpy.packbits <https://numpy.org/doc/stable/reference/generated/numpy.packbits.html>`_
for more details). FlashInfer accepts both boolean masks and bit-packed masks. If a boolean mask is provided, FlashInfer packs it into a bit-packed
array internally.
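
For instance, a boolean mask can be bit-packed with ``numpy.packbits`` (referenced above);
:meth:`flashinfer.quantization.packbits` provides the corresponding operation on torch tensors:

.. code:: python

    import numpy as np

    bool_mask = np.array([True, False, True, True, False, False, True, False, True], dtype=bool)

    # 1 bit per element, 8 elements per uint8, "little" bit-order
    packed = np.packbits(bool_mask, bitorder="little")

    # unpacking recovers the original mask (padded to a multiple of 8)
    unpacked = np.unpackbits(packed, bitorder="little")[: len(bool_mask)].astype(bool)
    assert (unpacked == bool_mask).all()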

FlashInfer APIs
~~~~~~~~~~~~~~~

:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
allow the user to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions.
The mask data is added to the attention scores before softmax (and after softmax scaling) in the attention kernel.

:meth:`flashinfer.quantization.packbits` and :meth:`flashinfer.quantization.segment_packbits` are utility functions
that pack a boolean mask into a bit-packed array.

.. _page-layout:

Page Table Layout
-----------------

When the KV-Cache is dynamic (e.g. in the append or decode stage), packing all keys/values into a single tensor is not
efficient because the sequence length per request changes over time. `vLLM <https://arxiv.org/pdf/2309.06180.pdf>`_
proposes organizing the KV-Cache as a page table. In FlashInfer, we treat the page table as
a block sparse matrix (each used page can be viewed as a non-zero block in the block sparse matrix)
and use the `CSR format <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html>`_
to index the pages in the KV-Cache.

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/page_layout.png
    :width: 800
    :align: center
    :alt: Data structure of Paged KV-Cache.

For each request, we keep a record of its ``page_indices`` and ``last_page_len``, which
track the pages used by this request and the number of entries in its last page. The KV
sequence length of request ``i`` is ``page_size * (len(page_indices[i]) - 1) + last_page_len[i]``.

.. note::
    The ``last_page_len`` of each request must be greater than zero, and less than or equal to ``page_size``.

The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed as:
``[0, len(page_indices[0]), len(page_indices[0])+len(page_indices[1]), ...]``.
The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_len``.
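
Concretely, the overall arrays can be assembled from the per-request records like this
(a minimal sketch; the page assignments below are hypothetical):

.. code:: python

    import torch

    page_size = 16

    # hypothetical per-request page tables
    page_indices = [[0, 3, 5], [1], [2, 4]]   # pages used by each request
    last_page_len = [9, 16, 1]                # entries in the last page of each request

    # CSR-style indptr over pages, with length num_requests + 1
    kv_indptr = torch.zeros(len(page_indices) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(torch.tensor([len(p) for p in page_indices]), dim=0)

    # concatenation of all requests' page indices and last-page lengths
    kv_page_indices = torch.tensor(sum(page_indices, []), dtype=torch.int32)
    kv_last_page_lens = torch.tensor(last_page_len, dtype=torch.int32)

    # KV length of request i: page_size * (num_pages_i - 1) + last_page_len[i]
    kv_lens = [page_size * (len(p) - 1) + l for p, l in zip(page_indices, last_page_len)]
    assert kv_lens == [41, 16, 17]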

The ``kv_data`` tensor can either be a single 5-D tensor or a tuple of two 4-D tensors.
When stored in a single tensor, ``kv_data`` has shape:

.. code:: python

    kv_cache_nhd = torch.empty(max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.bfloat16) # NHD layout
    kv_cache_hnd = torch.empty(max_num_pages, 2, num_heads, page_size, head_dim, dtype=torch.bfloat16) # HND layout

When stored in a tuple of tensors, ``kv_data = (k_data, v_data)``, and each of them has shape:

.. code:: python

    k_cache_nhd = torch.empty(max_num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16) # NHD layout
    k_cache_hnd = torch.empty(max_num_pages, num_heads, page_size, head_dim, dtype=torch.bfloat16) # HND layout
    v_cache_nhd = torch.empty(max_num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16) # NHD layout
    v_cache_hnd = torch.empty(max_num_pages, num_heads, page_size, head_dim, dtype=torch.bfloat16) # HND layout

where ``max_num_pages`` is the maximum number of pages used by all requests and ``page_size`` is the number of tokens
we fit into each page. The ``2`` in the single-tensor storage indexes keys and values (the first slot for keys, the second for values).
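
Putting the pieces together, the key vector of a given token of a given request can be located in the
single-tensor ``NHD`` cache as follows (a minimal sketch; the sizes and page assignments are hypothetical):

.. code:: python

    import torch

    max_num_pages, page_size, num_heads, head_dim = 8, 16, 32, 128
    kv_cache_nhd = torch.empty(max_num_pages, 2, page_size, num_heads, head_dim, dtype=torch.bfloat16)

    page_indices_i = [0, 3, 5]  # pages used by request i, in order
    token_pos = 20              # logical position of the token within request i's KV sequence

    page_id = page_indices_i[token_pos // page_size]  # physical page holding this token
    entry = token_pos % page_size                     # slot within that page

    k_vec = kv_cache_nhd[page_id, 0, entry]  # keys live at index 0, values at index 1
    assert k_vec.shape == (num_heads, head_dim)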

.. note::
    ``indptr`` arrays across the flashinfer library should be of type ``int32``. Arrays of type ``int64`` can cause indexing errors. This is also true of the ``kv_page_indices`` and ``kv_last_page_lens`` arrays.

.. _mla-page-layout:

Multi-head Latent Attention Page Layout
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Multi-head Latent Attention (MLA) is an attention mechanism proposed in `DeepSeek v2 <https://arxiv.org/abs/2405.04434>`_ and
used in later DeepSeek models. MLA unifies the key cache and value cache into a single tensor, so there is no need to store them separately.
Compared to multi-head attention or grouped query attention, the KV-Cache of MLA does not have a ``num_heads`` dimension,
so there is no distinction between ``NHD`` and ``HND`` layouts.

MLA separates the RoPE (Rotary Positional Encoding) dimensions from the other head dimensions. We use ``kpe`` (key w/ positional encoding) and ``ckv`` (compressed key/value)
to name these two components. Users can store them in a single Paged KV-Cache:

.. code:: python

    head_dim_ckv = 512
    head_dim_kpe = 64
    mla_paged_kv_cache = torch.empty(max_num_pages, page_size, head_dim_ckv + head_dim_kpe, dtype=torch.bfloat16)
    ckv = mla_paged_kv_cache[:, :, :head_dim_ckv]  # Slicing here does not copy or move data
    kpe = mla_paged_kv_cache[:, :, head_dim_ckv:]  # Slicing here does not copy or move data

``ckv`` and ``kpe`` can then be fed into the MLA attention kernel :class:`flashinfer.mla.BatchMLAPagedAttentionWrapper`.

FlashInfer APIs
~~~~~~~~~~~~~~~

:meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
(the pages for these appended keys/values must be allocated before calling this API).

:class:`flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
and prefill/append attention between queries stored in ragged tensors and keys/values stored in the paged KV-Cache.

.. _cascade-inference-data-layout:

Multi-level Cascade Inference Data Layout
-----------------------------------------

When using multi-level `cascade inference <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_,
the queries and outputs are stored in ragged tensors, and the KV-Caches of all levels are stored
in a unified Paged KV-Cache. Each level has a unique ``qo_indptr`` array, which is the prefix sum of the
number of tokens to append in each subtree, as well as ``kv_page_indptr``, ``kv_page_indices``, and
``kv_last_page_len`` arrays, which have the same semantics as in the :ref:`page-layout` section. The following figure
shows how to construct these data structures for the append attention operation for 8 requests whose
KV-Cache is treated as 3 levels for prefix reuse:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/cascade_inference_data_layout.png
    :width: 800
    :align: center
    :alt: Cascade inference data layout.

Note that we do not have to change the data layout of the ragged query/output tensors or the paged KV-Cache for each level:
all levels share the same underlying data layout, but we use different ``qo_indptr`` / ``kv_page_indptr`` arrays
so that we can view them in different ways.
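
For instance, with 4 requests grouped into a hypothetical 3-level prefix tree (1 shared root node, 2 intermediate nodes
covering 2 requests each, and 4 leaf nodes), the per-level ``qo_indptr`` arrays are just prefix sums of the token
counts in each node's subtree (a conceptual sketch, not tied to the figure above):

.. code:: python

    import torch

    qo_lens = [3, 1, 4, 2]  # number of tokens to append per request (hypothetical)

    def build_indptr(counts):
        """Prefix-sum indptr (int32) over a list of per-node token counts."""
        out = torch.zeros(len(counts) + 1, dtype=torch.int32)
        out[1:] = torch.cumsum(torch.tensor(counts), dim=0)
        return out

    # level 0: one node shared by all requests
    qo_indptr_l0 = build_indptr([sum(qo_lens)])                                 # [0, 10]
    # level 1: two nodes, each covering two requests
    qo_indptr_l1 = build_indptr([qo_lens[0] + qo_lens[1], qo_lens[2] + qo_lens[3]])  # [0, 4, 10]
    # level 2: one node per request
    qo_indptr_l2 = build_indptr(qo_lens)                                        # [0, 3, 4, 8, 10]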

FlashInfer APIs
~~~~~~~~~~~~~~~

FlashInfer provides :class:`flashinfer.cascade.MultiLevelCascadeAttentionWrapper` to compute
the cascade attention.

FAQ
---

How does FlashInfer manage KV-Cache?
    FlashInfer itself is not responsible for managing the page table (popping and allocating new pages, etc.); we leave the strategy
    to the user, since different serving engines might have different strategies for managing the page table. FlashInfer is only responsible
    for computing the attention between queries and the keys/values stored in the KV-Cache.