""" Copyright (c) 2024 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools from typing import Optional, Tuple import torch from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec from .utils import register_custom_op, register_fake_op def gen_rope_module() -> JitSpec: return gen_jit_spec( "rope", [ jit_env.FLASHINFER_CSRC_DIR / "rope.cu", jit_env.FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu", ], ) @functools.cache def get_rope_module(): return gen_rope_module().build_and_load() @register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope")) def _apply_rope( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: get_rope_module().apply_rope( q, k, q_rope, k_rope, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta, ) @register_fake_op("flashinfer::apply_rope") def _fake_apply_rope( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: pass @register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope")) def _apply_llama31_rope( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, low_freq_factor: float, high_freq_factor: float, old_context_len: float, ) -> None: get_rope_module().apply_llama31_rope( q, k, q_rope, k_rope, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_len, ) @register_fake_op("flashinfer::apply_llama31_rope") def _fake_apply_llama31_rope( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, low_freq_factor: float, high_freq_factor: float, old_context_len: float, ) -> None: pass @register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope")) def _apply_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: get_rope_module().apply_rope_pos_ids( q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta, ) @register_fake_op("flashinfer::apply_rope_pos_ids") def _fake_apply_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: pass @register_custom_op( "flashinfer::mla_rope_quantize", mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"), ) def _mla_rope_quantize( q_rope_in: torch.Tensor, k_rope_in: torch.Tensor, 
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    q_rope_out: torch.Tensor,
    k_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    k_nope_out: torch.Tensor,
    quant_scale_q: float,
    quant_scale_kv: float,
    interleave: bool,
) -> None:
    get_rope_module().mla_rope_quantize(
        q_rope_in,
        k_rope_in,
        q_nope_in,
        k_nope_in,
        q_rope_out,
        k_rope_out,
        q_nope_out,
        k_nope_out,
        cos_sin_cache,
        pos_ids,
        quant_scale_q,
        quant_scale_kv,
        interleave,
    )


@register_fake_op("flashinfer::mla_rope_quantize")
def _fake_mla_rope_quantize(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    q_rope_out: torch.Tensor,
    k_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    k_nope_out: torch.Tensor,
    quant_scale_q: float,
    quant_scale_kv: float,
    interleave: bool,
) -> None:
    pass


@register_custom_op(
    "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
)
def _apply_rope_pos_ids_cos_sin_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    interleave: bool,
) -> None:
    get_rope_module().apply_rope_pos_ids_cos_sin_cache(
        q,
        k,
        q_rope,
        k_rope,
        cos_sin_cache,
        pos_ids,
        interleave,
    )


@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
def _fake_apply_rope_pos_ids_cos_sin_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    interleave: bool,
) -> None:
    pass


@register_custom_op(
    "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope")
)
def _apply_llama31_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    get_rope_module().apply_llama31_rope_pos_ids(
        q,
        k,
        q_rope,
        k_rope,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        old_context_len,
    )


@register_fake_op("flashinfer::apply_llama31_rope_pos_ids")
def _fake_apply_llama31_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    pass


def apply_rope_inplace(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> None:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor)
    inplace. cos/sin values are computed on the fly inside the kernel.

    We use :attr:`indptr` to denote the start pointer of each segment in the batch:
    the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
    i-th segment is ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr`
    is always 0, and the last element is the total number of queries/keys in the batch.
    Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is
        the last element of ``indptr``.
k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``1``. rope_theta : float The theta value used in the rope embedding, default: ``1e4``. Examples -------- >>> import torch >>> import flashinfer >>> batch_size = 128 >>> qkv_len = 1024 >>> num_qo_heads = 32 >>> num_kv_heads = 32 >>> head_dim = 128 >>> nnz = batch_size * qkv_len >>> qkv_packed = torch.randn( >>> nnz, >>> (num_qo_heads + 2 * num_kv_heads) * head_dim, >>> dtype=torch.float16, >>> device="cuda:0", >>> ) >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) >>> k = qkv_packed[ ... :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ... ].reshape(nnz, num_kv_heads, head_dim) >>> indptr = torch.tensor( ... [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" >>> ) >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0") >>> flashinfer.apply_rope_inplace(q, k, indptr, offsets) See Also -------- apply_rope """ if rotary_dim is None: rotary_dim = q.size(-1) _apply_rope( q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta ) def apply_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, ) -> None: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. 
* If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``1``. rope_theta : float The theta value used in the rope embedding, default: ``1e4``. See Also -------- apply_rope_pos_ids """ if rotary_dim is None: rotary_dim = q.size(-1) _apply_rope_pos_ids( q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) def apply_llama31_rope_inplace( q: torch.Tensor, k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float The theta value used in the rope embedding, default: ``5e5``. low_freq_factor : float The low frequency factor used in Llama 3.1 RoPE, default: ``1``. high_freq_factor : float The high frequency factor used in Llama 3.1 RoPE, default: ``4``. old_context_len : int The old context length used in Llama 3.1 RoPE, default: ``8192``. 
Examples -------- >>> import torch >>> import flashinfer >>> batch_size = 128 >>> qkv_len = 1024 >>> num_qo_heads = 32 >>> num_kv_heads = 32 >>> head_dim = 128 >>> nnz = batch_size * qkv_len >>> qkv_packed = torch.randn( >>> nnz, >>> (num_qo_heads + 2 * num_kv_heads) * head_dim, >>> dtype=torch.float16, >>> device="cuda:0", >>> ) >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) >>> k = qkv_packed[ ... :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ... ].reshape(nnz, num_kv_heads, head_dim) >>> indptr = torch.tensor( ... [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" >>> ) >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0") >>> flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets) See Also -------- apply_llama31_rope """ if rotary_dim is None: rotary_dim = q.size(-1) _apply_llama31_rope( q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, float(old_context_len), ) def apply_llama31_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float The theta value used in the rope embedding, default: ``5e5``. low_freq_factor : float The low frequency factor used in Llama 3.1 RoPE, default: ``1``. high_freq_factor : float The high frequency factor used in Llama 3.1 RoPE, default: ``4``. old_context_len : int The old context length used in Llama 3.1 RoPE, default: ``8192``. 
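    Examples
    --------
    A minimal sketch mirroring the :func:`apply_llama31_rope_inplace` example, but
    driving the kernel with flattened per-token position ids (sizes are illustrative):

    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
    >>> flashinfer.apply_llama31_rope_pos_ids_inplace(q, k, pos_ids)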
See Also -------- apply_llama31_rope_pos_ids """ if rotary_dim is None: rotary_dim = q.size(-1) _apply_llama31_rope_pos_ids( q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, float(old_context_len), ) def apply_rope( q: torch.Tensor, k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``1``. rope_theta : float The theta value used in the rope embedding, default: ``1e4``. Returns ------- q_rope : torch.Tensor The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. k_rope : torch.Tensor The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. Examples -------- >>> import torch >>> import flashinfer >>> batch_size = 128 >>> qkv_len = 1024 >>> num_qo_heads = 32 >>> num_kv_heads = 32 >>> head_dim = 128 >>> nnz = batch_size * qkv_len >>> qkv_packed = torch.randn( >>> nnz, >>> (num_qo_heads + 2 * num_kv_heads) * head_dim, >>> dtype=torch.float16, >>> device="cuda:0", >>> ) >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) >>> k = qkv_packed[ ... :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ... ].reshape(nnz, num_kv_heads, head_dim) >>> indptr = torch.tensor( ... 
        [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    >>> )
    >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
    >>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
    >>> q_rope.shape
    torch.Size([131072, 32, 128])
    >>> k_rope.shape
    torch.Size([131072, 32, 128])

    See Also
    --------
    apply_rope_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )
    return q_rope, k_rope


def apply_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
    cos/sin values are computed on the fly inside the kernel.

    Unlike :func:`apply_rope`, this function takes the flattened position index of each
    query/key directly through :attr:`pos_ids` instead of deriving it from ``indptr``
    and ``offsets``. Please see :ref:`Ragged Tensor tutorial ` for more details about
    the ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is
        the total number of query/key tokens in the batch.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is
        the total number of query/key tokens in the batch.
    pos_ids : torch.Tensor
        Position indices, shape: ``(nnz)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire
        head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions,
        default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.
    rope_scale : float
        The scaling factor used in the rope embedding, default: ``1``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``1e4``.

    Returns
    -------
    q_rope : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
    k_rope : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.

    See Also
    --------
    apply_rope_pos_ids_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope_pos_ids(
        q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta
    )
    return q_rope, k_rope


def apply_llama31_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 8,
    rope_theta: float = 5e5,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
    RaggedTensor). cos/sin values are computed on the fly inside the kernel.
We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float The theta value used in the rope embedding, default: ``5e5``. low_freq_factor : float The low frequency factor used in Llama 3.1 RoPE, default: ``1``. high_freq_factor : float The high frequency factor used in Llama 3.1 RoPE, default: ``4``. old_context_len : int The old context length used in Llama 3.1 RoPE, default: ``8192``. Returns ------- q_rope : torch.Tensor The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. k_rope : torch.Tensor The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. Examples -------- >>> import torch >>> import flashinfer >>> batch_size = 128 >>> qkv_len = 1024 >>> num_qo_heads = 32 >>> num_kv_heads = 32 >>> head_dim = 128 >>> nnz = batch_size * qkv_len >>> qkv_packed = torch.randn( >>> nnz, >>> (num_qo_heads + 2 * num_kv_heads) * head_dim, >>> dtype=torch.float16, >>> device="cuda:0", >>> ) >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) >>> k = qkv_packed[ ... :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ... ].reshape(nnz, num_kv_heads, head_dim) >>> indptr = torch.tensor( ... 
[i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" >>> ) >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0") >>> q_rope, k_rope = flashinfer.apply_llama31_rope(q, k, indptr, offsets) >>> q_rope.shape torch.Size([131072, 32, 128]) >>> k_rope.shape torch.Size([131072, 32, 128]) See Also -------- apply_llama31_rope_inplace """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) if rotary_dim is None: rotary_dim = q.size(-1) _apply_llama31_rope( q, k, q_rope, k_rope, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, float(old_context_len), ) return q_rope, k_rope def apply_llama31_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more details about the ragged tensor. Parameters ---------- q : torch.Tensor Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. k : torch.Tensor Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. rotary_dim : Optional[int] The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])`` rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float The theta value used in the rope embedding, default: ``5e5``. low_freq_factor : float The low frequency factor used in Llama 3.1 RoPE, default: ``1``. high_freq_factor : float The high frequency factor used in Llama 3.1 RoPE, default: ``4``. old_context_len : int The old context length used in Llama 3.1 RoPE, default: ``8192``. Returns ------- q_rope : torch.Tensor The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. k_rope : torch.Tensor The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. 
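    Examples
    --------
    A minimal sketch, analogous to the :func:`apply_llama31_rope` example above but
    using flattened per-token position ids (sizes are illustrative):

    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(qkv_len, dtype=torch.int32, device="cuda:0").repeat(batch_size)
    >>> q_rope, k_rope = flashinfer.apply_llama31_rope_pos_ids(q, k, pos_ids)
    >>> q_rope.shape
    torch.Size([131072, 32, 128])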
    See Also
    --------
    apply_llama31_rope_pos_ids_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_llama31_rope_pos_ids(
        q,
        k,
        q_rope,
        k_rope,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        float(old_context_len),
    )
    return q_rope, k_rope


def apply_rope_with_cos_sin_cache(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        The size (``head_dim``) of each attention head.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Returns
    -------
    query_out : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key_out : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads * head_size)``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)

    _apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query_out.view(query_out.shape[0], -1, head_size),
        k_rope=key_out.view(key_out.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions,
        interleave=(not is_neox),
    )
    return query_out, key_out


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied inplace to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        The size (``head_dim``) of each attention head.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
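    Examples
    --------
    A minimal sketch; the cache construction below is illustrative and simply follows
    the cos/sin concatenation convention described above:

    >>> import torch
    >>> import flashinfer
    >>> num_tokens, num_heads, head_size, rotary_dim = 4, 8, 64, 64
    >>> inv_freq = 1.0 / (
    ...     1e4 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32, device="cuda:0") / rotary_dim)
    ... )
    >>> t = torch.arange(4096, dtype=torch.float32, device="cuda:0")
    >>> freqs = torch.outer(t, inv_freq)
    >>> cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
    >>> positions = torch.arange(num_tokens, dtype=torch.int32, device="cuda:0")
    >>> query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> flashinfer.apply_rope_with_cos_sin_cache_inplace(
    ...     positions, query, key, head_size, cos_sin_cache
    ... )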
Note ---- The rotary dimension is determined by the cosine cache and sine cache. """ if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") # pass q_rope and k_rope as q and k to perform inplace operation _apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), q_rope=query.view(query.shape[0], -1, head_size), k_rope=key.view(key.shape[0], -1, head_size), cos_sin_cache=cos_sin_cache, pos_ids=positions, interleave=(not is_neox), ) def mla_rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, q_nope: torch.Tensor, k_nope: torch.Tensor, cos_sin_cache: torch.Tensor, pos_ids: torch.Tensor, is_neox: bool = True, quantize_dtype: Optional[torch.dtype] = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, q_rope_out: Optional[torch.Tensor] = None, k_rope_out: Optional[torch.Tensor] = None, q_nope_out: Optional[torch.Tensor] = None, k_nope_out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") # Infer quantize_dtype from output tensors or default to float8_e4m3fn if quantize_dtype is None: for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): if out is not None: quantize_dtype = out.dtype break else: quantize_dtype = torch.float8_e4m3fn # Allocate output tensors if not provided q_rope_out = ( q_rope_out if q_rope_out is not None else torch.empty_like(q_rope, dtype=quantize_dtype) ) k_rope_out = ( k_rope_out if k_rope_out is not None else torch.empty_like(k_rope, dtype=quantize_dtype) ) q_nope_out = ( q_nope_out if q_nope_out is not None else torch.empty_like(q_nope, dtype=quantize_dtype) ) k_nope_out = ( k_nope_out if k_nope_out is not None else torch.empty_like(k_nope, dtype=quantize_dtype) ) _mla_rope_quantize( q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids, q_rope_out, k_rope_out, q_nope_out, k_nope_out, quant_scale_q, quant_scale_kv, not is_neox, # interleave ) return q_rope_out, k_rope_out, q_nope_out, k_nope_out
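

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the public API). It sketches how
# ``mla_rope_quantize_fp8`` can be called; the tensor shapes below assume a
# DeepSeek-style MLA layout (``num_heads`` query heads, a 64-dim rotary part
# shared across key heads, and a 512-dim "nope" latent part) and are
# assumptions for illustration, not requirements enforced by this wrapper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        nnz, num_heads, rope_dim, nope_dim = 4, 16, 64, 512
        q_rope = torch.randn(nnz, num_heads, rope_dim, dtype=torch.float16, device="cuda:0")
        k_rope = torch.randn(nnz, rope_dim, dtype=torch.float16, device="cuda:0")
        q_nope = torch.randn(nnz, num_heads, nope_dim, dtype=torch.float16, device="cuda:0")
        k_nope = torch.randn(nnz, nope_dim, dtype=torch.float16, device="cuda:0")
        # cos/sin cache over the rotary dimension, following the convention used by
        # apply_rope_with_cos_sin_cache: cos in the first half, sin in the second.
        t = torch.arange(4096, dtype=torch.float32, device="cuda:0")
        inv_freq = 1.0 / (
            1e4 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device="cuda:0") / rope_dim)
        )
        freqs = torch.outer(t, inv_freq)
        cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
        pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
        # quantize_dtype is inferred from any provided output tensor, otherwise it
        # defaults to torch.float8_e4m3fn (see the wrapper above).
        q_rope_fp8, k_rope_fp8, q_nope_fp8, k_nope_fp8 = mla_rope_quantize_fp8(
            q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids
        )
        print(q_rope_fp8.dtype, q_nope_fp8.shape)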