import math

import numpy as np
import pytest
import torch

from flashinfer import xqa


def set_random_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def round_up(a, b):
    return math.ceil(a / b) * b


def div_up(a, b):
    return math.ceil(a / b)


props = torch.cuda.get_device_properties(0)
sm_count = props.multi_processor_count

beam_width = 1
q_scale = 1.0


class CacheSeq:
    """View of one K or V head's token sequence inside the paged cache pool."""

    def __init__(
        self,
        pool: torch.Tensor,
        page_indices: torch.Tensor,
        nb_heads: int,
        idx_head: int,
        tokens_per_page: int = 32,
    ):
        self.pool = pool
        self.page_indices = page_indices
        self.nb_heads = nb_heads
        self.idx_head = idx_head
        self.tokens_per_page = tokens_per_page

    def __getitem__(self, i: int) -> torch.Tensor:
        page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32)
        idx_head = (
            self.tokens_per_page * self.nb_heads * page_idx
            + self.tokens_per_page * self.idx_head
            + i % self.tokens_per_page
        )
        return self.pool[idx_head]


def ref_attention(
    q,
    k_cache_seq,
    v_cache_seq,
    seq_len,
    q_scale,
    kv_scale,
    x_scale,
    attention_sinks,
    sliding_win_size,
    valid_elems_per_head,
):
    """Float32 reference attention over a paged KV cache for one K/V head group."""
    head_grp_size = q.shape[0]
    rcp_x_scale = 1.0 / x_scale
    qk_scale = q_scale * kv_scale / math.sqrt(valid_elems_per_head)
    q_f32 = q.to(torch.float32)  # [head_grp_size, valid_elems_per_head]

    k_cache_f32 = torch.zeros(
        seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda"
    )
    v_cache_f32 = torch.zeros(
        seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda"
    )
    for j in range(seq_len):
        k_cache_f32[j] = k_cache_seq[j].to(torch.float32)
        v_cache_f32[j] = v_cache_seq[j].to(torch.float32)

    # q_f32: [head_grp_size, valid_elems_per_head]
    # k_cache_f32: [seq_len, valid_elems_per_head]
    # gemm0_acc: [head_grp_size, seq_len]
    gemm0_acc = torch.zeros(
        head_grp_size, seq_len, dtype=torch.float32, device=q_f32.device
    )

    # Calculate sliding window start position
    if sliding_win_size == 0 or seq_len < sliding_win_size:
        seq_beg = 0
    else:
        seq_beg = seq_len - sliding_win_size

    # Set positions before the sliding window to negative infinity (masking)
    if seq_beg > 0:
        gemm0_acc[:, :seq_beg] = float("-inf")

    # q_f32: [head_grp_size, valid_elems_per_head]
    # k_cache_f32[seq_beg:seq_len]: [valid_seq_len, valid_elems_per_head]
    if seq_beg < seq_len:
        valid_k_cache = k_cache_f32[
            seq_beg:seq_len
        ]  # [valid_seq_len, valid_elems_per_head]
        valid_scores = (
            torch.matmul(q_f32, valid_k_cache.t()) * qk_scale
        )  # [head_grp_size, valid_seq_len]
        gemm0_acc[:, seq_beg:seq_len] = valid_scores

    row_max = torch.max(gemm0_acc, dim=1, keepdim=True)[0]  # [head_grp_size, 1]
    x = torch.exp(gemm0_acc - row_max)  # [head_grp_size, seq_len]
    row_sum = torch.sum(x, dim=1, keepdim=True)  # [head_grp_size, 1]
    x = x * rcp_x_scale

    if seq_beg < seq_len:
        valid_x = x[:, seq_beg:seq_len]  # [head_grp_size, valid_seq_len]
        valid_v_cache = v_cache_f32[
            seq_beg:seq_len
        ]  # [valid_seq_len, valid_elems_per_head]
        out = torch.matmul(
            valid_x, valid_v_cache
        )  # [head_grp_size, valid_elems_per_head]
    else:
        out = torch.zeros(
            head_grp_size,
            valid_elems_per_head,
            dtype=torch.float32,
            device=q_f32.device,
        )

    # Attention sinks contribute an extra exp(sink) term to each head's softmax
    # denominator without a corresponding value row.
    if attention_sinks is not None:
        sink_weights = torch.exp(
            attention_sinks - row_max.squeeze(-1)
        )  # [head_grp_size]
        row_sum.squeeze(-1)[:] += sink_weights

    out = out * (x_scale * kv_scale) / row_sum
    return out

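# test_xqa exercises the flashinfer xqa kernel across paged-KV layouts, head
# configurations, sequence lengths, sliding-window and attention-sink settings,
# and checks its output against ref_attention above.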
@pytest.mark.parametrize("use_sliding_window", [True, False])
@pytest.mark.parametrize("use_fp16", [True, False])
@pytest.mark.parametrize("use_attention_sinks", [True, False])
@pytest.mark.parametrize("seq_len", [2, 15, 256, 514])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("nb_k_heads", [1, 4, 8])
@pytest.mark.parametrize("tokens_per_page", [16, 64])
@pytest.mark.parametrize("valid_elems_per_head", [32, 128])
@pytest.mark.parametrize("head_grp_size", [8, 16])
def test_xqa(
    batch_size,
    nb_k_heads,
    seq_len,
    tokens_per_page,
    use_fp16,
    valid_elems_per_head,
    head_grp_size,
    use_attention_sinks,
    use_sliding_window,
):
    set_random_seed(42)
    nb_v_heads = nb_k_heads
    nb_q_heads = nb_k_heads * head_grp_size

    output = torch.zeros(
        batch_size,
        beam_width,
        nb_q_heads,
        valid_elems_per_head,
        dtype=torch.bfloat16 if not use_fp16 else torch.float16,
        device="cuda",
    )
    output.fill_(float("nan"))
    q_heads = torch.zeros(
        batch_size,
        beam_width,
        nb_q_heads,
        valid_elems_per_head,
        dtype=torch.bfloat16 if not use_fp16 else torch.float16,
        device="cuda",
    )
    q_heads.normal_(0, 1)

    if use_attention_sinks:
        attention_sinks = torch.zeros(
            nb_k_heads, head_grp_size, dtype=torch.float32, device="cuda"
        )
        for i in range(nb_k_heads):
            for j in range(head_grp_size):
                attention_sinks[i, j] = 2.0 + float(j % 4)
    else:
        attention_sinks = None

    if use_sliding_window:
        sliding_win_size = 256
    else:
        sliding_win_size = 0

    max_seq_len = round_up(seq_len, tokens_per_page)
    total_nb_cache_heads = (
        (nb_k_heads + nb_v_heads) * max_seq_len * beam_width * batch_size
    )
    cache_heads = torch.zeros(
        total_nb_cache_heads,
        valid_elems_per_head,
        dtype=torch.bfloat16 if not use_fp16 else torch.float16,
        device="cuda",
    )
    cache_heads.normal_(0, 1)

    nb_pages_per_seq = div_up(max_seq_len, tokens_per_page)
    total_nb_pages = nb_pages_per_seq * 2 * beam_width * batch_size
    page_list_arg = torch.zeros(
        batch_size, beam_width, 2, nb_pages_per_seq, dtype=torch.uint32, device="cuda"
    )
    page_list_arg.view(-1)[:total_nb_pages] = torch.arange(
        total_nb_pages, dtype=torch.int32, device="cuda"
    ).to(torch.uint32)
    # Shuffle the page list so pages are not laid out contiguously in the pool.
    flattened = page_list_arg.flatten()
    indices = torch.randperm(flattened.numel())
    shuffled_flat = flattened.to(torch.int32)[indices].to(torch.uint32)
    page_list_arg = shuffled_flat.view(page_list_arg.shape)

    def cache_head_at(
        batch,
        is_k,
        idx_kv_head,
        pos,
        cache_heads,
        page_list,
        beam_width,
        nb_k_heads,
        tokens_per_page,
    ):
        beam = 0
        kv = 0 if is_k else 1
        page_idx = page_list_arg[batch][beam][kv][pos // tokens_per_page].to(
            torch.int32
        )
        idx_head = (
            tokens_per_page * (nb_k_heads * page_idx + idx_kv_head)
            + pos % tokens_per_page
        )
        return cache_heads[idx_head]

    # Zero out cache entries beyond seq_len so padding tokens cannot affect the result.
    for batch in range(batch_size):
        for kv in range(2):
            for idx_kv_head in range(nb_k_heads):
                for pos in range(seq_len, max_seq_len):
                    cache_head = cache_head_at(
                        batch,
                        kv == 0,
                        idx_kv_head,
                        pos,
                        cache_heads,
                        page_list_arg,
                        beam_width,
                        nb_k_heads,
                        tokens_per_page,
                    )
                    cache_head.fill_(0.0)

    seq_len_list = torch.zeros(
        batch_size, beam_width, dtype=torch.uint32, device="cuda"
    )
    seq_len_list.fill_(seq_len)

    kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda")

    nb_seq = nb_k_heads * batch_size
    nb_semaphores = round_up(nb_seq, 2) + 2 + nb_seq + 2
    semaphores = torch.zeros(nb_semaphores, dtype=torch.uint32, device="cuda")
    scratch_size = 256 << 20
    scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda")

    xqa(
        use_fp16,
        tokens_per_page,
        valid_elems_per_head,
        head_grp_size,
        use_sliding_window,
        sliding_win_size,
        sm_count,
        nb_k_heads,
        q_scale,
        output,
        q_heads,
        attention_sinks,
        cache_heads,
        page_list_arg,
        max_seq_len,
        seq_len_list,
        batch_size,
        kv_cache_scale,
        semaphores,
        scratch_buf,
    )
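    # Check the kernel output against the float32 reference, one K/V head group
    # at a time; the loose 1e-2 tolerance allows for fp16/bf16 rounding in the
    # kernel.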
    for req in range(batch_size):
        for b in range(beam_width):
            for idx_k_head in range(nb_k_heads):
                k_cache_seq = CacheSeq(
                    pool=cache_heads,
                    page_indices=page_list_arg[req][b][0],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                )
                v_cache_seq = CacheSeq(
                    pool=cache_heads,
                    page_indices=page_list_arg[req][b][1],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                )
                ref_output = ref_attention(
                    q=q_heads[req][b][
                        idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
                    ],
                    k_cache_seq=k_cache_seq,
                    v_cache_seq=v_cache_seq,
                    seq_len=seq_len,
                    q_scale=q_scale,
                    kv_scale=kv_cache_scale[0],
                    x_scale=1.0,
                    attention_sinks=attention_sinks[idx_k_head, :]
                    if use_attention_sinks
                    else None,
                    sliding_win_size=sliding_win_size if use_sliding_window else 0,
                    valid_elems_per_head=valid_elems_per_head,
                )
                kernel_output = output[req][b][
                    idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
                ].to(torch.float32)
                assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01)