from typing import Any, Dict, List, Optional, Tuple, Union

import pytest
import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace


# vLLM torch-native reference implementation
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox style: rotate the first and second halves of the head dim as pairs.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate interleaved even/odd elements of the head dim.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(torch.nn.Module):
    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
                / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert back to the requested dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)

        return query, key


class FlashInferRotaryEmbedding(RotaryEmbedding):
    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The sgl_kernel op applies RoPE to query and key in place,
        # reusing the precomputed cos/sin cache.
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key


@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
    [
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
):
    # Build the PyTorch-native reference and the sgl_kernel-backed implementation
    # with identical configuration, then compare their outputs on the same inputs.
    rope_ref = RotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)

    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )

    # Clone inputs because forward_cuda modifies query/key in place.
    query_ref, key_ref = query.clone(), key.clone()
    query_flashinfer, key_flashinfer = query.clone(), key.clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        pos_ids, query_flashinfer, key_flashinfer
    )

    print(query_ref_out)
    print(query_flashinfer_out)

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    pytest.main([__file__])
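

# Illustrative sketch, not part of the test suite above: a minimal CPU-only
# sanity check of the two rotation layouts handled by _apply_rotary_emb. The
# helper name and tensor shapes are taken from this file; nothing here touches
# the sgl_kernel CUDA op, and the function is never invoked by pytest.
def _demo_rotary_layouts() -> None:
    x = torch.arange(8, dtype=torch.float32).view(1, 1, 8)  # [tokens, heads, head_size]
    cos = torch.ones(1, 4)   # [tokens, head_size // 2]
    sin = torch.zeros(1, 4)  # [tokens, head_size // 2]
    # With cos == 1 and sin == 0 the rotation is the identity, so both the
    # Neox (half-split) and GPT-J (interleaved) layouts must return the input.
    assert torch.equal(_apply_rotary_emb(x, cos, sin, is_neox_style=True), x)
    assert torch.equal(_apply_rotary_emb(x, cos, sin, is_neox_style=False), x)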