"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple

import pytest
import torch
from rope_reference import *

import flashinfer


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204])
@pytest.mark.parametrize("num_qo_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("offset", [0, 15, 99])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("llama_version", ["llama", "llama31"])
@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
def test_rope(
    batch_size,
    qkv_len,
    num_qo_heads,
    num_kv_heads,
    offset,
    head_dim,
    llama_version,
    partial_rotary_factor,
    inplace,
):
    """Compare FlashInfer's indptr/offsets RoPE entry points with the reference.

    Builds a packed fp16 QKV buffer on ``cuda:0``, rotates the first
    ``rotary_dim`` channels of q and k with ``apply_rotary_emb`` (the
    reference), and checks the result against ``flashinfer.apply_rope`` /
    ``flashinfer.apply_llama31_rope`` or their ``*_inplace`` variants.

    Args:
        batch_size: Number of sequences; every sequence has ``qkv_len`` tokens.
        qkv_len: Tokens per sequence.
        num_qo_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        offset: Starting position applied to every sequence.
        head_dim: Per-head channel count.
        llama_version: ``"llama"`` or anything else (treated as llama 3.1).
        partial_rotary_factor: Fraction of ``head_dim`` that is rotated.
        inplace: Whether to exercise the in-place API variant.
    """
    rotary_dim = int(head_dim * partial_rotary_factor)
    nnz = batch_size * qkv_len

    # q and k are carved out of one packed buffer rather than allocated
    # separately (presumably to exercise strided inputs -- TODO confirm).
    qkv_packed = torch.randn(
        nnz,
        (num_qo_heads + 2 * num_kv_heads) * head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    k = qkv_packed[
        :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ].reshape(nnz, num_kv_heads, head_dim)

    # Sequence i occupies rows [indptr[i], indptr[i + 1]) of q and k.
    indptr = torch.tensor(
        [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
    )
    offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0")

    # reference implementation
    # NOTE: this must run before the FlashInfer call below, because the
    # in-place variants overwrite q and k.
    if llama_version == "llama":
        freqs_cis = precompute_freqs_cis(
            rotary_dim, qkv_len + offset, 10000.0, use_scaled=False, device="cuda:0"
        ).to("cuda:0")
    else:
        freqs_cis = precompute_freqs_cis(
            rotary_dim, qkv_len + offset, 5e5, use_scaled=True, device="cuda:0"
        ).to("cuda:0")
    q_rot_ref, k_rot_ref = apply_rotary_emb(
        q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[..., :rotary_dim],
        k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[..., :rotary_dim],
        freqs_cis[offset : offset + qkv_len],
    )
    # Channels past rotary_dim are expected to pass through unchanged.
    q_pass_ref = q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[
        ..., rotary_dim:
    ]
    k_pass_ref = k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[
        ..., rotary_dim:
    ]
    q_rope_ref = torch.cat([q_rot_ref, q_pass_ref], dim=-1).reshape(
        nnz, num_qo_heads, head_dim
    )
    k_rope_ref = torch.cat([k_rot_ref, k_pass_ref], dim=-1).reshape(
        nnz, num_kv_heads, head_dim
    )

    # flashinfer implementation
    if llama_version == "llama":
        if inplace:
            flashinfer.apply_rope_inplace(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=True,
                rope_theta=1e4,
            )
            q_rope, k_rope = q, k
        else:
            q_rope, k_rope = flashinfer.apply_rope(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=True,
                rope_theta=1e4,
            )
    else:
        if inplace:
            flashinfer.apply_llama31_rope_inplace(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=True,
                rope_theta=5e5,
            )
            q_rope, k_rope = q, k
        else:
            q_rope, k_rope = flashinfer.apply_llama31_rope(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=True,
                rope_theta=5e5,
            )

    # compare
    torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204])
@pytest.mark.parametrize("num_qo_heads", [8, 16])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("offset", [0, 15, 99])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("llama_version", ["llama", "llama31"])
@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.parametrize("interleave", [True, False])
@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
def test_rope_pos_ids(
    batch_size,
    qkv_len,
    num_qo_heads,
    num_kv_heads,
    offset,
    head_dim,
    llama_version,
    partial_rotary_factor,
    inplace,
    interleave,
    idtype,
):
    """Check that the ``pos_ids`` RoPE APIs agree with the indptr/offsets APIs.

    The same q/k data is rotated twice: once through the indptr/offsets entry
    point and once through the matching ``*_pos_ids`` entry point, using
    explicit position ids ``offset .. offset + qkv_len - 1`` repeated for each
    sequence. The two results are then compared.

    Args:
        batch_size: Number of sequences; every sequence has ``qkv_len`` tokens.
        qkv_len: Tokens per sequence.
        num_qo_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        offset: Starting position applied to every sequence.
        head_dim: Per-head channel count.
        llama_version: ``"llama"`` or anything else (treated as llama 3.1).
        partial_rotary_factor: Fraction of ``head_dim`` that is rotated.
        inplace: Whether to exercise the in-place API variants.
        interleave: Forwarded unchanged to the FlashInfer calls.
        idtype: Integer dtype used for ``indptr``, ``offsets`` and ``pos_ids``.
    """
    rotary_dim = int(head_dim * partial_rotary_factor)
    nnz = batch_size * qkv_len
    qkv_packed = torch.randn(
        nnz,
        (num_qo_heads + 2 * num_kv_heads) * head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim)
    k = qkv_packed[
        :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim
    ].reshape(nnz, num_kv_heads, head_dim)
    indptr = torch.tensor(
        [i * qkv_len for i in range(batch_size + 1)], dtype=idtype, device="cuda:0"
    )
    offsets = torch.full((batch_size,), offset, dtype=idtype, device="cuda:0")

    # Explicit per-token positions equivalent to (indptr, offsets) above.
    pos_ids = torch.cat(
        [
            torch.arange(offset, qkv_len + offset, dtype=idtype)
            for _ in range(batch_size)
        ]
    ).to("cuda:0")

    if llama_version == "llama":
        if inplace:
            # Clone first: the in-place call on q/k would otherwise destroy
            # the input needed by the pos_ids variant.
            q_clone, k_clone = q.clone(), k.clone()
            flashinfer.apply_rope_inplace(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=1e4,
            )
            q_rope, k_rope = q, k
            flashinfer.apply_rope_pos_ids_inplace(
                q_clone,
                k_clone,
                pos_ids,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=1e4,
            )
            q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone
        else:
            q_rope, k_rope = flashinfer.apply_rope(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=1e4,
            )

            q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids(
                q,
                k,
                pos_ids,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=1e4,
            )
    else:
        if inplace:
            # Clone first, for the same reason as in the "llama" branch.
            q_clone, k_clone = q.clone(), k.clone()
            flashinfer.apply_llama31_rope_inplace(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=5e5,
            )
            q_rope, k_rope = q, k
            flashinfer.apply_llama31_rope_pos_ids_inplace(
                q_clone,
                k_clone,
                pos_ids,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=5e5,
            )
            q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone
        else:
            q_rope, k_rope = flashinfer.apply_llama31_rope(
                q,
                k,
                indptr,
                offsets,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=5e5,
            )

            q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids(
                q,
                k,
                pos_ids,
                rotary_dim=rotary_dim,
                interleave=interleave,
                rope_theta=5e5,
            )

    # compare
    torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(k_rope_pos_ids, k_rope, rtol=1e-3, atol=1e-3)


class FlashInferRotaryEmbedding(RotaryEmbedding):
    """``RotaryEmbedding`` whose CUDA path is backed by FlashInfer.

    Only ``forward_cuda`` is overridden; everything else (including
    ``forward_native`` and ``cos_sin_cache`` used by the tests below) comes
    from the ``RotaryEmbedding`` base class.
    """

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to ``query`` and ``key`` in place via FlashInfer.

        Args:
            positions: Position ids forwarded to FlashInfer.
            query: Query tensor; updated in place.
            key: Key tensor; updated in place.
            offsets: Accepted for signature compatibility; not used here.

        Returns:
            The same ``query`` and ``key`` objects that were passed in.
        """
        flashinfer.apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key


# NOTE(review): the 4th and 5th parameter rows below are identical; the
# duplicate is kept so the collected test set is unchanged.
@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
    [
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
        (64, 32, 2048, 8432, True, torch.bfloat16, "cuda", 2, 199, 4, 1),
        (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2),
    ],
)
def test_rope_cos_sin_cache(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
):
    """Compare the FlashInfer cos/sin-cache RoPE with the native reference.

    Two embeddings are built from identical constructor arguments. The
    reference runs ``forward_native`` and the FlashInfer-backed one runs
    ``forward_cuda``, each on its own clone of the same random inputs.
    """
    rope_ref = RotaryEmbedding(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
        device,
    )
    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
        device,
    )

    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )

    # Separate clones so neither implementation can see the other's writes.
    query_ref, key_ref = query.clone(), key.clone()
    query_flashinfer, key_flashinfer = query.clone(), key.clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        pos_ids, query_flashinfer, key_flashinfer
    )

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_mla_rope_quantize(
    num_tokens,
    input_dtype,
    quant_dtype,
):
    """Check ``flashinfer.rope.mla_rope_quantize_fp8`` against a cast baseline.

    The baseline applies ``forward_native`` to the full 576-channel q/k
    tensors and casts the result to ``quant_dtype``. The FlashInfer kernel
    receives the first 64 channels as the rope part and the remaining
    channels as the nope part, writing into matching slices of preallocated
    fp8 outputs.

    Args:
        num_tokens: Number of tokens (rows) in q and k.
        input_dtype: Floating-point dtype of the inputs.
        quant_dtype: FP8 dtype of the outputs.
    """
    device = "cuda:0"
    num_qo_heads = 128
    q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device)
    k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device)
    pos_ids = torch.arange(num_tokens, device=device)

    # baseline
    rope_flashinfer = FlashInferRotaryEmbedding(
        576,
        64,
        4096,
        10000,
        False,  # is_neox_style
        input_dtype,
        device,
    )

    q_out_f16_ref, k_out_f16_ref = rope_flashinfer.forward_native(pos_ids, q_in, k_in)
    q_out_f8_ref, k_out_f8_ref = map(
        lambda x: x.to(quant_dtype),
        (q_out_f16_ref, k_out_f16_ref),
    )

    q_out = torch.empty_like(q_in, dtype=quant_dtype)
    k_out = torch.empty_like(k_in, dtype=quant_dtype)
    flashinfer.rope.mla_rope_quantize_fp8(
        q_in[..., :64],
        k_in[..., :64],
        q_in[..., 64:],
        k_in[..., 64:],
        rope_flashinfer.cos_sin_cache,
        pos_ids,
        is_neox=False,
        q_rope_out=q_out[..., :64],
        k_rope_out=k_out[..., :64],
        q_nope_out=q_out[..., 64:],
        k_nope_out=k_out[..., 64:],
        quant_scale_q=1.0,
        quant_scale_kv=1.0,
    )

    # Compare in float32; the loose rtol allows for FP8 rounding differences.
    torch.testing.assert_close(
        q_out_f8_ref.float(), q_out.float(), atol=1e-2, rtol=2e-1
    )
    torch.testing.assert_close(
        k_out_f8_ref.float(), k_out.float(), atol=1e-2, rtol=2e-1
    )


if __name__ == "__main__":
    # test_rope(2, 1, 8, 8, 1, 128, "llama", 1.0, False)
    # test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False, True, torch.int32)
    # test_rope_cos_sin_cache(
    #     64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1
    # )
    # test_mla_rope_quantize takes (num_tokens, input_dtype, quant_dtype).
    test_mla_rope_quantize(1, torch.float16, torch.float8_e4m3fn)