sglang_v0.5.2/flashinfer_0.3.1/flashinfer/rope.py

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
from typing import Optional, Tuple
import torch
from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec
from .utils import register_custom_op, register_fake_op
def gen_rope_module() -> JitSpec:
return gen_jit_spec(
"rope",
[
jit_env.FLASHINFER_CSRC_DIR / "rope.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu",
],
)
@functools.cache
def get_rope_module():
return gen_rope_module().build_and_load()
@register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope"))
def _apply_rope(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
) -> None:
get_rope_module().apply_rope(
q,
k,
q_rope,
k_rope,
indptr,
offsets,
rotary_dim,
interleave,
rope_scale,
rope_theta,
)
@register_fake_op("flashinfer::apply_rope")
def _fake_apply_rope(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
) -> None:
pass
@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope"))
def _apply_llama31_rope(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
low_freq_factor: float,
high_freq_factor: float,
old_context_len: float,
) -> None:
get_rope_module().apply_llama31_rope(
q,
k,
q_rope,
k_rope,
indptr,
offsets,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
old_context_len,
)
@register_fake_op("flashinfer::apply_llama31_rope")
def _fake_apply_llama31_rope(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
low_freq_factor: float,
high_freq_factor: float,
old_context_len: float,
) -> None:
pass
@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope"))
def _apply_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
) -> None:
get_rope_module().apply_rope_pos_ids(
q,
k,
q_rope,
k_rope,
pos_ids,
rotary_dim,
interleave,
rope_scale,
rope_theta,
)
@register_fake_op("flashinfer::apply_rope_pos_ids")
def _fake_apply_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
) -> None:
pass
@register_custom_op(
"flashinfer::mla_rope_quantize",
mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
)
def _mla_rope_quantize(
q_rope_in: torch.Tensor,
k_rope_in: torch.Tensor,
q_nope_in: torch.Tensor,
k_nope_in: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
q_rope_out: torch.Tensor,
k_rope_out: torch.Tensor,
q_nope_out: torch.Tensor,
k_nope_out: torch.Tensor,
quant_scale_q: float,
quant_scale_kv: float,
interleave: bool,
) -> None:
get_rope_module().mla_rope_quantize(
q_rope_in,
k_rope_in,
q_nope_in,
k_nope_in,
q_rope_out,
k_rope_out,
q_nope_out,
k_nope_out,
cos_sin_cache,
pos_ids,
quant_scale_q,
quant_scale_kv,
interleave,
)
@register_fake_op("flashinfer::mla_rope_quantize")
def _fake_mla_rope_quantize(
q_rope_in: torch.Tensor,
k_rope_in: torch.Tensor,
q_nope_in: torch.Tensor,
k_nope_in: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
q_rope_out: torch.Tensor,
k_rope_out: torch.Tensor,
q_nope_out: torch.Tensor,
k_nope_out: torch.Tensor,
quant_scale_q: float,
quant_scale_kv: float,
interleave: bool,
) -> None:
pass
@register_custom_op(
"flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
)
def _apply_rope_pos_ids_cos_sin_cache(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
interleave: bool,
) -> None:
get_rope_module().apply_rope_pos_ids_cos_sin_cache(
q,
k,
q_rope,
k_rope,
cos_sin_cache,
pos_ids,
interleave,
)
@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
def _fake_apply_rope_pos_ids_cos_sin_cache(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
interleave: bool,
) -> None:
pass
@register_custom_op(
"flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope")
)
def _apply_llama31_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
low_freq_factor: float,
high_freq_factor: float,
old_context_len: float,
) -> None:
get_rope_module().apply_llama31_rope_pos_ids(
q,
k,
q_rope,
k_rope,
pos_ids,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
old_context_len,
)
@register_fake_op("flashinfer::apply_llama31_rope_pos_ids")
def _fake_apply_llama31_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
q_rope: torch.Tensor,
k_rope: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: int,
interleave: bool,
rope_scale: float,
rope_theta: float,
low_freq_factor: float,
high_freq_factor: float,
old_context_len: float,
) -> None:
pass
def apply_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 1,
rope_theta: float = 1e4,
) -> None:
r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace.
cos/sin values are computed on the fly inside the kernel.
We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the last
element is the total number of queries/keys in the batch.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
indptr : torch.Tensor
Indptr tensor, shape: ``(batch_size + 1)``.
offsets : torch.Tensor
The relative position offsets of each query in the batch, shape: ``(batch_size)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``1``.
rope_theta : float
The theta value used in the rope embedding, default: ``1e4``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...     nnz,
...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
...     dtype=torch.float16,
...     device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> flashinfer.apply_rope_inplace(q, k, indptr, offsets)
See Also
--------
apply_rope
"""
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_rope(
q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta
)
def apply_rope_pos_ids_inplace(
q: torch.Tensor,
k: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 1,
rope_theta: float = 1e4,
) -> None:
r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace.
cos/sin values are computed on the fly inside the kernel.
The rotary position of each row is given directly by :attr:`pos_ids`: ``pos_ids[i]`` is
the position used to rotate ``q[i]`` and ``k[i]``, so no ``indptr``/``offsets`` pair is
required.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
pos_ids : torch.Tensor
Position indices, shape: ``(nnz)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``1``.
rope_theta : float
The theta value used in the rope embedding, default: ``1e4``.
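Examples
--------
A minimal usage sketch mirroring the :func:`apply_rope_inplace` example, but driving the
rotary positions directly with ``pos_ids``; the shapes below are illustrative.

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
>>> flashinfer.apply_rope_pos_ids_inplace(q, k, pos_ids)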
See Also
--------
apply_rope_pos_ids
"""
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_rope_pos_ids(
q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta
)
def apply_llama31_rope_inplace(
q: torch.Tensor,
k: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 8,
rope_theta: float = 5e5,
low_freq_factor: float = 1,
high_freq_factor: float = 4,
old_context_len: int = 8192,
) -> None:
r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel.
We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the last
element is the total number of queries/keys in the batch.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
indptr : torch.Tensor
Indptr tensor, shape: ``(batch_size + 1)``.
offsets : torch.Tensor
The relative position offsets of each query in the batch, shape: ``(batch_size)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``8``.
rope_theta : float
The theta value used in the rope embedding, default: ``5e5``.
low_freq_factor : float
The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
high_freq_factor : float
The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
old_context_len : int
The old context length used in Llama 3.1 RoPE, default: ``8192``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...     nnz,
...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
...     dtype=torch.float16,
...     device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets)
See Also
--------
apply_llama31_rope
"""
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_llama31_rope(
q,
k,
q,
k,
indptr,
offsets,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
float(old_context_len),
)
def apply_llama31_rope_pos_ids_inplace(
q: torch.Tensor,
k: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 8,
rope_theta: float = 5e5,
low_freq_factor: float = 1,
high_freq_factor: float = 4,
old_context_len: int = 8192,
) -> None:
r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel.
The rotary position of each row is given directly by :attr:`pos_ids`: ``pos_ids[i]`` is
the position used to rotate ``q[i]`` and ``k[i]``, so no ``indptr``/``offsets`` pair is
required.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
pos_ids : torch.Tensor
Position indices, shape: ``(nnz)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``8``.
rope_theta : float
The theta value used in the rope embedding, default: ``5e5``.
low_freq_factor : float
The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
high_freq_factor : float
The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
old_context_len : int
The old context length used in Llama 3.1 RoPE, default: ``8192``.
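Examples
--------
A minimal usage sketch, analogous to :func:`apply_llama31_rope_inplace` but with explicit
``pos_ids``; the shapes below are illustrative.

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 8
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
>>> flashinfer.apply_llama31_rope_pos_ids_inplace(q, k, pos_ids)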
See Also
--------
apply_llama31_rope_pos_ids
"""
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_llama31_rope_pos_ids(
q,
k,
q,
k,
pos_ids,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
float(old_context_len),
)
def apply_rope(
q: torch.Tensor,
k: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 1,
rope_theta: float = 1e4,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
cos/sin values are computed on the fly inside the kernel.
We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the last
element is the total number of queries/keys in the batch.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
indptr : torch.Tensor
Indptr tensor, shape: ``(batch_size + 1)``.
offsets : torch.Tensor
The relative position offsets of each query in the batch, shape: ``(batch_size)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``1``.
rope_theta : float
The theta value used in the rope embedding, default: ``1e4``.
Returns
-------
q_rope : torch.Tensor
The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
k_rope : torch.Tensor
The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...     nnz,
...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
...     dtype=torch.float16,
...     device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
>>> q_rope.shape
torch.Size([131072, 32, 128])
>>> k_rope.shape
torch.Size([131072, 32, 128])
See Also
--------
apply_rope_inplace
"""
q_rope = torch.empty_like(q)
k_rope = torch.empty_like(k)
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_rope(
q,
k,
q_rope,
k_rope,
indptr,
offsets,
rotary_dim,
interleave,
rope_scale,
rope_theta,
)
return q_rope, k_rope
def apply_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 1,
rope_theta: float = 1e4,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
cos/sin values are computed on the fly inside the kernel.
The rotary position of each row is given directly by :attr:`pos_ids`: ``pos_ids[i]`` is
the position used to rotate ``q[i]`` and ``k[i]``, so no ``indptr``/``offsets`` pair is
required.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
pos_ids : torch.Tensor
Position indices, shape: ``(nnz)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``1``.
rope_theta : float
The theta value used in the rope embedding, default: ``1e4``.
Returns
-------
q_rope : torch.Tensor
The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
k_rope : torch.Tensor
The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
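Examples
--------
A minimal usage sketch; per-token positions are passed explicitly via ``pos_ids`` and the
shapes below are illustrative.

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
>>> q_rope, k_rope = flashinfer.apply_rope_pos_ids(q, k, pos_ids)
>>> q_rope.shape
torch.Size([131072, 32, 128])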
See Also
--------
apply_rope_pos_ids_inplace
"""
q_rope = torch.empty_like(q)
k_rope = torch.empty_like(k)
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_rope_pos_ids(
q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta
)
return q_rope, k_rope
def apply_llama31_rope(
q: torch.Tensor,
k: torch.Tensor,
indptr: torch.Tensor,
offsets: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 8,
rope_theta: float = 5e5,
low_freq_factor: float = 1,
high_freq_factor: float = 4,
old_context_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
RaggedTensor). cos/sin values are computed on the fly inside the kernel.
We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the last
element is the total number of queries/keys in the batch.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
element of ``indptr``.
indptr : torch.Tensor
Indptr tensor, shape: ``(batch_size + 1)``.
offsets : torch.Tensor
The relative position offsets of each query in the batch, shape: ``(batch_size)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``8``.
rope_theta : float
The theta value used in the rope embedding, default: ``5e5``.
low_freq_factor : float
The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
high_freq_factor : float
The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
old_context_len : int
The old context length used in Llama 3.1 RoPE, default: ``8192``.
Returns
-------
q_rope : torch.Tensor
The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
k_rope : torch.Tensor
The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 32
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> qkv_packed = torch.randn(
...     nnz,
...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
...     dtype=torch.float16,
...     device="cuda:0",
... )
>>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
>>> k = qkv_packed[
...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
... ].reshape(nnz, num_kv_heads, head_dim)
>>> indptr = torch.tensor(
...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
... )
>>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
>>> q_rope, k_rope = flashinfer.apply_llama31_rope(q, k, indptr, offsets)
>>> q_rope.shape
torch.Size([131072, 32, 128])
>>> k_rope.shape
torch.Size([131072, 32, 128])
See Also
--------
apply_llama31_rope_inplace
"""
q_rope = torch.empty_like(q)
k_rope = torch.empty_like(k)
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_llama31_rope(
q,
k,
q_rope,
k_rope,
indptr,
offsets,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
float(old_context_len),
)
return q_rope, k_rope
def apply_llama31_rope_pos_ids(
q: torch.Tensor,
k: torch.Tensor,
pos_ids: torch.Tensor,
rotary_dim: Optional[int] = None,
interleave: bool = False,
rope_scale: float = 8,
rope_theta: float = 5e5,
low_freq_factor: float = 1,
high_freq_factor: float = 4,
old_context_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
RaggedTensor). cos/sin values are computed on the fly inside the kernel.
The rotary position of each row is given directly by :attr:`pos_ids`: ``pos_ids[i]`` is
the position used to rotate ``q[i]`` and ``k[i]``, so no ``indptr``/``offsets`` pair is
required.
Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
ragged tensor.
Parameters
----------
q : torch.Tensor
Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
k : torch.Tensor
Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the total
number of query/key tokens in the batch.
pos_ids : torch.Tensor
Position indices, shape: ``(nnz)``.
rotary_dim : Optional[int]
The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
interleave : bool
Whether to use interleaved layout in the last dimension, default: ``False``.
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
* If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
rope_scale : float
The scaling factor used in the rope embedding, default: ``8``.
rope_theta : float
The theta value used in the rope embedding, default: ``5e5``.
low_freq_factor : float
The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
high_freq_factor : float
The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
old_context_len : int
The old context length used in Llama 3.1 RoPE, default: ``8192``.
Returns
-------
q_rope : torch.Tensor
The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
k_rope : torch.Tensor
The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
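Examples
--------
A minimal usage sketch with explicit ``pos_ids``; the shapes below are illustrative.

>>> import torch
>>> import flashinfer
>>> batch_size = 128
>>> qkv_len = 1024
>>> num_qo_heads = 32
>>> num_kv_heads = 8
>>> head_dim = 128
>>> nnz = batch_size * qkv_len
>>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
>>> q_rope, k_rope = flashinfer.apply_llama31_rope_pos_ids(q, k, pos_ids)
>>> q_rope.shape
torch.Size([131072, 32, 128])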
See Also
--------
apply_llama31_rope_pos_ids_inplace
"""
q_rope = torch.empty_like(q)
k_rope = torch.empty_like(k)
if rotary_dim is None:
rotary_dim = q.size(-1)
_apply_llama31_rope_pos_ids(
q,
k,
q_rope,
k_rope,
pos_ids,
rotary_dim,
interleave,
rope_scale,
rope_theta,
low_freq_factor,
high_freq_factor,
float(old_context_len),
)
return q_rope, k_rope
def apply_rope_with_cos_sin_cache(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
Cosine is the first half and Sine is the second half on rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
Returns
-------
query_out : torch.Tensor
The rotated query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key_out : torch.Tensor
The rotated key tensor, shape: ``(nnz, num_k_heads * head_size)``.
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
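Examples
--------
A minimal sketch showing one way to build a vLLM-style ``cos_sin_cache`` (cosine in the
first half of the last dimension, sine in the second half); ``rope_theta`` and the shapes
below are illustrative.

>>> import torch
>>> import flashinfer
>>> nnz = 4096
>>> num_qo_heads, num_kv_heads, head_size = 32, 8, 128
>>> rotary_dim, max_seq_len, rope_theta = head_size, 8192, 1e4
>>> query = torch.randn(nnz, num_qo_heads * head_size, dtype=torch.float16, device="cuda:0")
>>> key = torch.randn(nnz, num_kv_heads * head_size, dtype=torch.float16, device="cuda:0")
>>> positions = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
>>> inv_freq = 1.0 / (
...     rope_theta
...     ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32, device="cuda:0") / rotary_dim)
... )
>>> t = torch.arange(max_seq_len, dtype=torch.float32, device="cuda:0")
>>> freqs = torch.outer(t, inv_freq)
>>> cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
>>> query_out, key_out = flashinfer.apply_rope_with_cos_sin_cache(
...     positions, query, key, head_size, cos_sin_cache, is_neox=True
... )
>>> query_out.shape
torch.Size([4096, 4096])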
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
query_out = torch.empty_like(query)
key_out = torch.empty_like(key)
_apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query_out.view(query_out.shape[0], -1, head_size),
k_rope=key_out.view(key_out.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions,
interleave=(not is_neox),
)
return query_out, key_out
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
The result is inplace applied to the input tensors.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
Cosine is the first half and Sine is the second half on rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
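Examples
--------
A minimal sketch; the cache layout is the same as in :func:`apply_rope_with_cos_sin_cache`,
and the rotation is written back into ``query``/``key``. The shapes are illustrative.

>>> import torch
>>> import flashinfer
>>> nnz, num_qo_heads, num_kv_heads, head_size = 4096, 32, 8, 128
>>> query = torch.randn(nnz, num_qo_heads * head_size, dtype=torch.float16, device="cuda:0")
>>> key = torch.randn(nnz, num_kv_heads * head_size, dtype=torch.float16, device="cuda:0")
>>> positions = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
>>> inv_freq = 1.0 / (
...     1e4 ** (torch.arange(0, head_size, 2, dtype=torch.float32, device="cuda:0") / head_size)
... )
>>> freqs = torch.outer(torch.arange(8192, dtype=torch.float32, device="cuda:0"), inv_freq)
>>> cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
>>> flashinfer.apply_rope_with_cos_sin_cache_inplace(
...     positions, query, key, head_size, cos_sin_cache
... )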
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
# pass q_rope and k_rope as q and k to perform inplace operation
_apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions,
interleave=(not is_neox),
)
def mla_rope_quantize_fp8(
q_rope: torch.Tensor,
k_rope: torch.Tensor,
q_nope: torch.Tensor,
k_nope: torch.Tensor,
cos_sin_cache: torch.Tensor,
pos_ids: torch.Tensor,
is_neox: bool = True,
quantize_dtype: Optional[torch.dtype] = None,
quant_scale_q: float = 1.0,
quant_scale_kv: float = 1.0,
q_rope_out: Optional[torch.Tensor] = None,
k_rope_out: Optional[torch.Tensor] = None,
q_nope_out: Optional[torch.Tensor] = None,
k_nope_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
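r"""Apply RoPE to the rotary parts (``q_rope``/``k_rope``) of MLA queries/keys using a
precomputed ``cos_sin_cache``, and quantize both the rotary and the non-rotary
(``q_nope``/``k_nope``) parts with the given scales.

Output buffers may be supplied via the ``*_out`` arguments; otherwise they are allocated
with ``quantize_dtype``, which is inferred from any provided output tensor and defaults to
``torch.float8_e4m3fn``. Returns ``(q_rope_out, k_rope_out, q_nope_out, k_nope_out)``.
"""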
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
# Infer quantize_dtype from output tensors or default to float8_e4m3fn
if quantize_dtype is None:
for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out):
if out is not None:
quantize_dtype = out.dtype
break
else:
quantize_dtype = torch.float8_e4m3fn
# Allocate output tensors if not provided
q_rope_out = (
q_rope_out
if q_rope_out is not None
else torch.empty_like(q_rope, dtype=quantize_dtype)
)
k_rope_out = (
k_rope_out
if k_rope_out is not None
else torch.empty_like(k_rope, dtype=quantize_dtype)
)
q_nope_out = (
q_nope_out
if q_nope_out is not None
else torch.empty_like(q_nope, dtype=quantize_dtype)
)
k_nope_out = (
k_nope_out
if k_nope_out is not None
else torch.empty_like(k_nope, dtype=quantize_dtype)
)
_mla_rope_quantize(
q_rope,
k_rope,
q_nope,
k_nope,
cos_sin_cache,
pos_ids,
q_rope_out,
k_rope_out,
q_nope_out,
k_nope_out,
quant_scale_q,
quant_scale_kv,
not is_neox, # interleave
)
return q_rope_out, k_rope_out, q_nope_out, k_nope_out