"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import functools
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from .jit import JitSpec
|
|
from .jit import env as jit_env
|
|
from .jit import gen_jit_spec
|
|
from .utils import register_custom_op, register_fake_op
|
|
|
|
|
|
def gen_rope_module() -> JitSpec:
    return gen_jit_spec(
        "rope",
        [
            jit_env.FLASHINFER_CSRC_DIR / "rope.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu",
        ],
    )


@functools.cache
def get_rope_module():
    return gen_rope_module().build_and_load()

@register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope"))
def _apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
) -> None:
    get_rope_module().apply_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )


@register_fake_op("flashinfer::apply_rope")
def _fake_apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
) -> None:
    pass

@register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope"))
def _apply_llama31_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    get_rope_module().apply_llama31_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        old_context_len,
    )


@register_fake_op("flashinfer::apply_llama31_rope")
def _fake_apply_llama31_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    pass

@register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope"))
def _apply_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
) -> None:
    get_rope_module().apply_rope_pos_ids(
        q,
        k,
        q_rope,
        k_rope,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )


@register_fake_op("flashinfer::apply_rope_pos_ids")
def _fake_apply_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
) -> None:
    pass

@register_custom_op(
    "flashinfer::mla_rope_quantize",
    mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
)
def _mla_rope_quantize(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    q_rope_out: torch.Tensor,
    k_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    k_nope_out: torch.Tensor,
    quant_scale_q: float,
    quant_scale_kv: float,
    interleave: bool,
) -> None:
    get_rope_module().mla_rope_quantize(
        q_rope_in,
        k_rope_in,
        q_nope_in,
        k_nope_in,
        q_rope_out,
        k_rope_out,
        q_nope_out,
        k_nope_out,
        cos_sin_cache,
        pos_ids,
        quant_scale_q,
        quant_scale_kv,
        interleave,
    )


@register_fake_op("flashinfer::mla_rope_quantize")
def _fake_mla_rope_quantize(
    q_rope_in: torch.Tensor,
    k_rope_in: torch.Tensor,
    q_nope_in: torch.Tensor,
    k_nope_in: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    q_rope_out: torch.Tensor,
    k_rope_out: torch.Tensor,
    q_nope_out: torch.Tensor,
    k_nope_out: torch.Tensor,
    quant_scale_q: float,
    quant_scale_kv: float,
    interleave: bool,
) -> None:
    pass

@register_custom_op(
    "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope")
)
def _apply_rope_pos_ids_cos_sin_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    interleave: bool,
) -> None:
    get_rope_module().apply_rope_pos_ids_cos_sin_cache(
        q,
        k,
        q_rope,
        k_rope,
        cos_sin_cache,
        pos_ids,
        interleave,
    )


@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
def _fake_apply_rope_pos_ids_cos_sin_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    interleave: bool,
) -> None:
    pass

@register_custom_op(
    "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope")
)
def _apply_llama31_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    get_rope_module().apply_llama31_rope_pos_ids(
        q,
        k,
        q_rope,
        k_rope,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        old_context_len,
    )


@register_fake_op("flashinfer::apply_llama31_rope_pos_ids")
def _fake_apply_llama31_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: int,
    interleave: bool,
    rope_scale: float,
    rope_theta: float,
    low_freq_factor: float,
    high_freq_factor: float,
    old_context_len: float,
) -> None:
    pass

def apply_rope_inplace(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> None:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace.
    cos/sin values are computed on the fly inside the kernel.

    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
    ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the
    last element of :attr:`indptr` is the total number of queries/keys in the batch.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    indptr : torch.Tensor
        Indptr tensor, shape: ``(batch_size + 1)``.
    offsets : torch.Tensor
        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``1``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``1e4``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> qkv_packed = torch.randn(
    ...     nnz,
    ...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
    ...     dtype=torch.float16,
    ...     device="cuda:0",
    ... )
    >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    >>> k = qkv_packed[
    ...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ... ].reshape(nnz, num_kv_heads, head_dim)
    >>> indptr = torch.tensor(
    ...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    ... )
    >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
    >>> flashinfer.apply_rope_inplace(q, k, indptr, offsets)

    See Also
    --------
    apply_rope
    """
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope(
        q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta
    )

def apply_rope_pos_ids_inplace(
    q: torch.Tensor,
    k: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> None:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace.
    cos/sin values are computed on the fly inside the kernel.

    The batched queries/keys are packed as ragged tensors with ``nnz`` tokens in total, and
    :attr:`pos_ids` gives the position index used for the rotary embedding of each token.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    pos_ids : torch.Tensor
        Position indices, shape: ``(nnz)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``1``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``1e4``.
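
    Examples
    --------
    A minimal usage sketch with hypothetical sizes (assumes a CUDA device is available):

    >>> import torch
    >>> import flashinfer
    >>> nnz = 2048
    >>> num_heads = 32
    >>> head_dim = 128
    >>> q = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> flashinfer.apply_rope_pos_ids_inplace(q, k, pos_ids)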

    See Also
    --------
    apply_rope_pos_ids
    """
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope_pos_ids(
        q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta
    )

def apply_llama31_rope_inplace(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 8,
    rope_theta: float = 5e5,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> None:
    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
    RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel.

    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
    ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the
    last element of :attr:`indptr` is the total number of queries/keys in the batch.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    indptr : torch.Tensor
        Indptr tensor, shape: ``(batch_size + 1)``.
    offsets : torch.Tensor
        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``8``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``5e5``.
    low_freq_factor : float
        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
    high_freq_factor : float
        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
    old_context_len : int
        The old context length used in Llama 3.1 RoPE, default: ``8192``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> qkv_packed = torch.randn(
    ...     nnz,
    ...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
    ...     dtype=torch.float16,
    ...     device="cuda:0",
    ... )
    >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    >>> k = qkv_packed[
    ...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ... ].reshape(nnz, num_kv_heads, head_dim)
    >>> indptr = torch.tensor(
    ...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    ... )
    >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
    >>> flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets)

    See Also
    --------
    apply_llama31_rope
    """
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_llama31_rope(
        q,
        k,
        q,
        k,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        float(old_context_len),
    )

def apply_llama31_rope_pos_ids_inplace(
    q: torch.Tensor,
    k: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 8,
    rope_theta: float = 5e5,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> None:
    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
    RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel.

    The batched queries/keys are packed as ragged tensors with ``nnz`` tokens in total, and
    :attr:`pos_ids` gives the position index used for the rotary embedding of each token.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    pos_ids : torch.Tensor
        Position indices, shape: ``(nnz)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``8``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``5e5``.
    low_freq_factor : float
        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
    high_freq_factor : float
        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
    old_context_len : int
        The old context length used in Llama 3.1 RoPE, default: ``8192``.
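
    Examples
    --------
    A minimal usage sketch with hypothetical sizes (assumes a CUDA device is available):

    >>> import torch
    >>> import flashinfer
    >>> nnz = 2048
    >>> num_heads = 32
    >>> head_dim = 128
    >>> q = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> flashinfer.apply_llama31_rope_pos_ids_inplace(q, k, pos_ids)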

    See Also
    --------
    apply_llama31_rope_pos_ids
    """
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_llama31_rope_pos_ids(
        q,
        k,
        q,
        k,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        float(old_context_len),
    )

def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
    cos/sin values are computed on the fly inside the kernel.

    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
    ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the
    last element of :attr:`indptr` is the total number of queries/keys in the batch.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    indptr : torch.Tensor
        Indptr tensor, shape: ``(batch_size + 1)``.
    offsets : torch.Tensor
        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``1``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``1e4``.

    Returns
    -------
    q_rope : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
    k_rope : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> qkv_packed = torch.randn(
    ...     nnz,
    ...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
    ...     dtype=torch.float16,
    ...     device="cuda:0",
    ... )
    >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    >>> k = qkv_packed[
    ...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ... ].reshape(nnz, num_kv_heads, head_dim)
    >>> indptr = torch.tensor(
    ...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    ... )
    >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
    >>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
    >>> q_rope.shape
    torch.Size([131072, 32, 128])
    >>> k_rope.shape
    torch.Size([131072, 32, 128])

    See Also
    --------
    apply_rope_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )
    return q_rope, k_rope

def apply_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 1,
    rope_theta: float = 1e4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
    cos/sin values are computed on the fly inside the kernel.

    The batched queries/keys are packed as ragged tensors with ``nnz`` tokens in total, and
    :attr:`pos_ids` gives the position index used for the rotary embedding of each token.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    pos_ids : torch.Tensor
        Position indices, shape: ``(nnz)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``1``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``1e4``.

    Returns
    -------
    q_rope : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
    k_rope : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
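
    Examples
    --------
    A minimal usage sketch with hypothetical sizes (assumes a CUDA device is available):

    >>> import torch
    >>> import flashinfer
    >>> nnz = 2048
    >>> num_heads = 32
    >>> head_dim = 128
    >>> q = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> q_rope, k_rope = flashinfer.apply_rope_pos_ids(q, k, pos_ids)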

    See Also
    --------
    apply_rope_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_rope_pos_ids(
        q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta
    )
    return q_rope, k_rope

def apply_llama31_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    indptr: torch.Tensor,
    offsets: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 8,
    rope_theta: float = 5e5,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
    RaggedTensor). cos/sin values are computed on the fly inside the kernel.

    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
    ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr` is always 0, and the
    last element of :attr:`indptr` is the total number of queries/keys in the batch.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
        element of ``indptr``.
    indptr : torch.Tensor
        Indptr tensor, shape: ``(batch_size + 1)``.
    offsets : torch.Tensor
        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``8``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``5e5``.
    low_freq_factor : float
        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
    high_freq_factor : float
        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
    old_context_len : int
        The old context length used in Llama 3.1 RoPE, default: ``8192``.

    Returns
    -------
    q_rope : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
    k_rope : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> batch_size = 128
    >>> qkv_len = 1024
    >>> num_qo_heads = 32
    >>> num_kv_heads = 32
    >>> head_dim = 128
    >>> nnz = batch_size * qkv_len
    >>> qkv_packed = torch.randn(
    ...     nnz,
    ...     (num_qo_heads + 2 * num_kv_heads) * head_dim,
    ...     dtype=torch.float16,
    ...     device="cuda:0",
    ... )
    >>> q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    >>> k = qkv_packed[
    ...     :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ... ].reshape(nnz, num_kv_heads, head_dim)
    >>> indptr = torch.tensor(
    ...     [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    ... )
    >>> offsets = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda:0")
    >>> q_rope, k_rope = flashinfer.apply_llama31_rope(q, k, indptr, offsets)
    >>> q_rope.shape
    torch.Size([131072, 32, 128])
    >>> k_rope.shape
    torch.Size([131072, 32, 128])

    See Also
    --------
    apply_llama31_rope_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_llama31_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        float(old_context_len),
    )
    return q_rope, k_rope

def apply_llama31_rope_pos_ids(
    q: torch.Tensor,
    k: torch.Tensor,
    pos_ids: torch.Tensor,
    rotary_dim: Optional[int] = None,
    interleave: bool = False,
    rope_scale: float = 8,
    rope_theta: float = 5e5,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
    RaggedTensor). cos/sin values are computed on the fly inside the kernel.

    The batched queries/keys are packed as ragged tensors with ``nnz`` tokens in total, and
    :attr:`pos_ids` gives the position index used for the rotary embedding of each token.
    Please see :ref:`Ragged Tensor tutorial <kv-layout>` for more details about the
    ragged tensor.

    Parameters
    ----------
    q : torch.Tensor
        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    k : torch.Tensor
        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the
        total number of tokens in the batch.
    pos_ids : torch.Tensor
        Position indices, shape: ``(nnz)``.
    rotary_dim : Optional[int]
        The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension,
        otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``.
    interleave : bool
        Whether to use interleaved layout in the last dimension, default: ``False``.

        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

    rope_scale : float
        The scaling factor used in the rope embedding, default: ``8``.
    rope_theta : float
        The theta value used in the rope embedding, default: ``5e5``.
    low_freq_factor : float
        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
    high_freq_factor : float
        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
    old_context_len : int
        The old context length used in Llama 3.1 RoPE, default: ``8192``.

    Returns
    -------
    q_rope : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
    k_rope : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
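
    Examples
    --------
    A minimal usage sketch with hypothetical sizes (assumes a CUDA device is available):

    >>> import torch
    >>> import flashinfer
    >>> nnz = 2048
    >>> num_heads = 32
    >>> head_dim = 128
    >>> q = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> k = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> q_rope, k_rope = flashinfer.apply_llama31_rope_pos_ids(q, k, pos_ids)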

    See Also
    --------
    apply_llama31_rope_pos_ids_inplace
    """
    q_rope = torch.empty_like(q)
    k_rope = torch.empty_like(k)
    if rotary_dim is None:
        rotary_dim = q.size(-1)
    _apply_llama31_rope_pos_ids(
        q,
        k,
        q_rope,
        k_rope,
        pos_ids,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
        low_freq_factor,
        high_freq_factor,
        float(old_context_len),
    )
    return q_rope, k_rope

def apply_rope_with_cos_sin_cache(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        The dimension (``head_dim``) of each attention head.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Returns
    -------
    query_out : torch.Tensor
        The rotated query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key_out : torch.Tensor
        The rotated key tensor, shape: ``(nnz, num_k_heads * head_size)``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
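
    Examples
    --------
    A minimal sketch with hypothetical sizes; the ``cos_sin_cache`` is built here with the
    standard ``theta = 10000`` RoPE frequencies purely for illustration (assumes a CUDA
    device is available):

    >>> import torch
    >>> import flashinfer
    >>> nnz, num_heads, head_size, max_seq_len = 1024, 32, 128, 4096
    >>> rotary_dim = head_size
    >>> inv_freq = 1.0 / (
    ...     10000.0
    ...     ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32, device="cuda:0") / rotary_dim)
    ... )
    >>> t = torch.arange(max_seq_len, dtype=torch.float32, device="cuda:0")
    >>> freqs = torch.outer(t, inv_freq)
    >>> cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
    >>> query = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> key = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> positions = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> query_out, key_out = flashinfer.apply_rope_with_cos_sin_cache(
    ...     positions, query, key, head_size, cos_sin_cache
    ... )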
"""
|
|
if cos_sin_cache.dtype != torch.float32:
|
|
raise ValueError("cos_sin_cache should be float32")
|
|
|
|
query_out = torch.empty_like(query)
|
|
key_out = torch.empty_like(key)
|
|
|
|
_apply_rope_pos_ids_cos_sin_cache(
|
|
q=query.view(query.shape[0], -1, head_size),
|
|
k=key.view(key.shape[0], -1, head_size),
|
|
q_rope=query_out.view(query_out.shape[0], -1, head_size),
|
|
k_rope=key_out.view(key_out.shape[0], -1, head_size),
|
|
cos_sin_cache=cos_sin_cache,
|
|
pos_ids=positions,
|
|
interleave=(not is_neox),
|
|
)
|
|
|
|
return query_out, key_out
|
|
|
|
|
|
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied inplace to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        The dimension (``head_dim``) of each attention head.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
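
    Examples
    --------
    A minimal sketch with hypothetical sizes; ``cos_sin_cache`` is assumed to be built as in
    :func:`apply_rope_with_cos_sin_cache` (float32, shape ``(max_seq_len, rotary_dim)``), and
    a CUDA device is assumed to be available:

    >>> import torch
    >>> import flashinfer
    >>> nnz, num_heads, head_size = 1024, 32, 128
    >>> query = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> key = torch.randn(nnz, num_heads * head_size, dtype=torch.float16, device="cuda:0")
    >>> positions = torch.arange(nnz, dtype=torch.int32, device="cuda:0")
    >>> flashinfer.apply_rope_with_cos_sin_cache_inplace(
    ...     positions, query, key, head_size, cos_sin_cache
    ... )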
"""
|
|
if cos_sin_cache.dtype != torch.float32:
|
|
raise ValueError("cos_sin_cache should be float32")
|
|
|
|
# pass q_rope and k_rope as q and k to perform inplace operation
|
|
_apply_rope_pos_ids_cos_sin_cache(
|
|
q=query.view(query.shape[0], -1, head_size),
|
|
k=key.view(key.shape[0], -1, head_size),
|
|
q_rope=query.view(query.shape[0], -1, head_size),
|
|
k_rope=key.view(key.shape[0], -1, head_size),
|
|
cos_sin_cache=cos_sin_cache,
|
|
pos_ids=positions,
|
|
interleave=(not is_neox),
|
|
)
|
|
|
|
|
|
def mla_rope_quantize_fp8(
    q_rope: torch.Tensor,
    k_rope: torch.Tensor,
    q_nope: torch.Tensor,
    k_nope: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    is_neox: bool = True,
    quantize_dtype: Optional[torch.dtype] = None,
    quant_scale_q: float = 1.0,
    quant_scale_kv: float = 1.0,
    q_rope_out: Optional[torch.Tensor] = None,
    k_rope_out: Optional[torch.Tensor] = None,
    q_nope_out: Optional[torch.Tensor] = None,
    k_nope_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
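    r"""Apply RoPE (using a precomputed cos/sin cache) to the rotary parts of MLA
    queries/keys and quantize both the rotary and non-rotary parts to an FP8 dtype.

    This docstring summarizes the Python-level behavior of this wrapper; the exact
    numerical semantics are defined by the underlying ``mla_rope_quantize`` kernel.

    Parameters
    ----------
    q_rope : torch.Tensor
        Rotary part of the query tensor (unquantized input).
    k_rope : torch.Tensor
        Rotary part of the key tensor (unquantized input).
    q_nope : torch.Tensor
        Non-rotary part of the query tensor (unquantized input).
    k_nope : torch.Tensor
        Non-rotary part of the key tensor (unquantized input).
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, must be float32.
    pos_ids : torch.Tensor
        Position indices of each entry.
    is_neox : bool
        Whether to use Neox style (non-interleaved) RoPE, default: ``True``.
    quantize_dtype : Optional[torch.dtype]
        Output quantization dtype; if ``None``, it is inferred from the provided output
        tensors, falling back to ``torch.float8_e4m3fn``.
    quant_scale_q : float
        Quantization scale applied to the query, default: ``1.0``.
    quant_scale_kv : float
        Quantization scale applied to the key, default: ``1.0``.
    q_rope_out, k_rope_out, q_nope_out, k_nope_out : Optional[torch.Tensor]
        Optional preallocated output tensors; allocated with ``quantize_dtype`` if not given.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ``(q_rope_out, k_rope_out, q_nope_out, k_nope_out)``.
    """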
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    # Infer quantize_dtype from output tensors or default to float8_e4m3fn
    if quantize_dtype is None:
        for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out):
            if out is not None:
                quantize_dtype = out.dtype
                break
        else:
            quantize_dtype = torch.float8_e4m3fn

    # Allocate output tensors if not provided
    q_rope_out = (
        q_rope_out
        if q_rope_out is not None
        else torch.empty_like(q_rope, dtype=quantize_dtype)
    )
    k_rope_out = (
        k_rope_out
        if k_rope_out is not None
        else torch.empty_like(k_rope, dtype=quantize_dtype)
    )
    q_nope_out = (
        q_nope_out
        if q_nope_out is not None
        else torch.empty_like(q_nope, dtype=quantize_dtype)
    )
    k_nope_out = (
        k_nope_out
        if k_nope_out is not None
        else torch.empty_like(k_nope, dtype=quantize_dtype)
    )

    _mla_rope_quantize(
        q_rope,
        k_rope,
        q_nope,
        k_nope,
        cos_sin_cache,
        pos_ids,
        q_rope_out,
        k_rope_out,
        q_nope_out,
        k_nope_out,
        quant_scale_q,
        quant_scale_kv,
        not is_neox,  # interleave
    )

    return q_rope_out, k_rope_out, q_nope_out, k_nope_out