sglang_v0.5.2/flashinfer_0.3.1/tests/test_rope.py

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
from rope_reference import *
import flashinfer
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204])
@pytest.mark.parametrize("num_qo_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("offset", [0, 15, 99])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("llama_version", ["llama", "llama31"])
@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
def test_rope(
batch_size,
qkv_len,
num_qo_heads,
num_kv_heads,
offset,
head_dim,
llama_version,
partial_rotary_factor,
inplace,
):
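    # Compare FlashInfer's batched RoPE kernels (apply_rope / apply_llama31_rope,
    # plus their in-place variants) against the pure-PyTorch reference from
    # rope_reference, including partial-rotary configurations where only the
    # first `rotary_dim` dimensions of each head are rotated.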
rotary_dim = int(head_dim * partial_rotary_factor)
nnz = batch_size * qkv_len
qkv_packed = torch.randn(
nnz,
(num_qo_heads + 2 * num_kv_heads) * head_dim,
dtype=torch.float16,
device="cuda:0",
)
q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
k = qkv_packed[
:, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
].reshape(nnz, num_kv_heads, head_dim)
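    # `indptr` marks the ragged-batch boundaries (every request contributes
    # `qkv_len` tokens); `offsets` is the RoPE position offset of each request.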
indptr = torch.tensor(
[i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
)
offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0")
# reference implementation
if llama_version == "llama":
freqs_cis = precompute_freqs_cis(
rotary_dim, qkv_len + offset, 10000.0, use_scaled=False, device="cuda:0"
).to("cuda:0")
else:
freqs_cis = precompute_freqs_cis(
rotary_dim, qkv_len + offset, 5e5, use_scaled=True, device="cuda:0"
).to("cuda:0")
q_rot_ref, k_rot_ref = apply_rotary_emb(
q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[..., :rotary_dim],
k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[..., :rotary_dim],
freqs_cis[offset : offset + qkv_len],
)
q_pass_ref = q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[
..., rotary_dim:
]
k_pass_ref = k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[
..., rotary_dim:
]
q_rope_ref = torch.cat([q_rot_ref, q_pass_ref], dim=-1).reshape(
nnz, num_qo_heads, head_dim
)
k_rope_ref = torch.cat([k_rot_ref, k_pass_ref], dim=-1).reshape(
nnz, num_kv_heads, head_dim
)
# flashinfer implementation
if llama_version == "llama":
if inplace:
flashinfer.apply_rope_inplace(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=True,
rope_theta=1e4,
)
q_rope, k_rope = q, k
else:
q_rope, k_rope = flashinfer.apply_rope(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=True,
rope_theta=1e4,
)
else:
if inplace:
flashinfer.apply_llama31_rope_inplace(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=True,
rope_theta=5e5,
)
q_rope, k_rope = q, k
else:
q_rope, k_rope = flashinfer.apply_llama31_rope(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=True,
rope_theta=5e5,
)
# compare
torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204])
@pytest.mark.parametrize("num_qo_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("offset", [0, 15, 99])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("llama_version", ["llama", "llama31"])
@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.parametrize("interleave", [True, False])
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
def test_rope_pos_ids(
batch_size,
qkv_len,
num_qo_heads,
num_kv_heads,
offset,
head_dim,
llama_version,
partial_rotary_factor,
inplace,
interleave,
idtype,
):
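    # Check that the indptr/offsets interface and the flattened pos_ids
    # interface produce matching results for both the out-of-place and
    # in-place kernels.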
rotary_dim = int(head_dim * partial_rotary_factor)
nnz = batch_size * qkv_len
qkv_packed = torch.randn(
nnz,
(num_qo_heads + 2 * num_kv_heads) * head_dim,
dtype=torch.float16,
device="cuda:0",
)
q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
k = qkv_packed[
:, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
].reshape(nnz, num_kv_heads, head_dim)
indptr = torch.tensor(
[i * qkv_len for i in range(batch_size + 1)], dtype=idtype, device="cuda:0"
)
offsets = torch.full((batch_size,), offset, dtype=idtype, device="cuda:0")
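    # Flattened per-token positions: every request covers positions
    # [offset, offset + qkv_len), mirroring what indptr/offsets describe.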
pos_ids = torch.cat(
[
torch.arange(offset, qkv_len + offset, dtype=idtype)
for _ in range(batch_size)
]
).to("cuda:0")
if llama_version == "llama":
if inplace:
q_clone, k_clone = q.clone(), k.clone()
flashinfer.apply_rope_inplace(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=1e4,
)
q_rope, k_rope = q, k
flashinfer.apply_rope_pos_ids_inplace(
q_clone,
k_clone,
pos_ids,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=1e4,
)
q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone
else:
q_rope, k_rope = flashinfer.apply_rope(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=1e4,
)
q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids(
q,
k,
pos_ids,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=1e4,
)
else:
if inplace:
q_clone, k_clone = q.clone(), k.clone()
flashinfer.apply_llama31_rope_inplace(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=5e5,
)
q_rope, k_rope = q, k
flashinfer.apply_llama31_rope_pos_ids_inplace(
q_clone,
k_clone,
pos_ids,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=5e5,
)
q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone
else:
q_rope, k_rope = flashinfer.apply_llama31_rope(
q,
k,
indptr,
offsets,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=5e5,
)
q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids(
q,
k,
pos_ids,
rotary_dim=rotary_dim,
interleave=interleave,
rope_theta=5e5,
)
# compare
torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(k_rope_pos_ids, k_rope, rtol=1e-3, atol=1e-3)


class FlashInferRotaryEmbedding(RotaryEmbedding):
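    """RotaryEmbedding from rope_reference whose CUDA path is overridden to use
    FlashInfer's fused cos/sin-cache kernel."""
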
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
flashinfer.apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
        return query, key


@pytest.mark.parametrize(
"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
[
(64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
(256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
(64, 32, 2048, 8432, True, torch.bfloat16, "cuda", 2, 199, 4, 1),
(64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1),
(64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1),
(256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2),
],
)
def test_rope_cos_sin_cache(
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
batch_size: int,
seq_len: int,
num_q_heads: int,
num_kv_heads: int,
):
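    # Compare the reference RotaryEmbedding.forward_native against
    # FlashInferRotaryEmbedding.forward_cuda (FlashInfer's fused cos/sin-cache
    # kernel) for both values of is_neox_style.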
rope_ref = RotaryEmbedding(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
rope_flashinfer = FlashInferRotaryEmbedding(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
)
key = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
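    # Work on separate copies: forward_cuda rotates query/key in place, so each
    # path must start from the same untouched inputs.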
query_ref, key_ref = query.clone(), key.clone()
query_flashinfer, key_flashinfer = query.clone(), key.clone()
query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
pos_ids, query_flashinfer, key_flashinfer
)
torch.testing.assert_close(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
)
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_mla_rope_quantize(
num_tokens,
input_dtype,
quant_dtype,
):
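    # Fused MLA RoPE + FP8 quantization: the first 64 dims of q/k are rotated,
    # the remaining 512 pass through unchanged, and both parts are written out
    # in `quant_dtype`. The reference applies RoPE via forward_native in the
    # input dtype and casts the result to FP8 afterwards.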
device = "cuda:0"
num_qo_heads = 128
q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device)
k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device)
pos_ids = torch.arange(num_tokens, device=device)
# baseline
rope_flashinfer = FlashInferRotaryEmbedding(
576,
64,
4096,
10000,
False, # is_neox_style
input_dtype,
device,
)
q_out_f16_ref, k_out_f16_ref = rope_flashinfer.forward_native(pos_ids, q_in, k_in)
q_out_f8_ref, k_out_f8_ref = map(
lambda x: x.to(quant_dtype),
(q_out_f16_ref, k_out_f16_ref),
)
q_out = torch.empty_like(q_in, dtype=quant_dtype)
k_out = torch.empty_like(k_in, dtype=quant_dtype)
flashinfer.rope.mla_rope_quantize_fp8(
q_in[..., :64],
k_in[..., :64],
q_in[..., 64:],
k_in[..., 64:],
rope_flashinfer.cos_sin_cache,
pos_ids,
is_neox=False,
q_rope_out=q_out[..., :64],
k_rope_out=k_out[..., :64],
q_nope_out=q_out[..., 64:],
k_nope_out=k_out[..., 64:],
quant_scale_q=1.0,
quant_scale_kv=1.0,
)
torch.testing.assert_close(
q_out_f8_ref.float(), q_out.float(), atol=1e-2, rtol=2e-1
)
torch.testing.assert_close(
k_out_f8_ref.float(), k_out.float(), atol=1e-2, rtol=2e-1
    )


if __name__ == "__main__":
    # test_rope(2, 1, 8, 8, 1, 128, "llama", 1.0, False)
    # test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False, True, torch.int32)
    # test_rope_cos_sin_cache(
    #     64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1
    # )
    test_mla_rope_quantize(1, torch.float16, torch.float8_e4m3fn)