import functools
from typing import Optional

import torch
from sgl_kernel import silu_and_mul


def get_scalar_type(num_bits: int, has_zp: bool):
    from sgl_kernel.scalar_type import scalar_types

    if has_zp:
        # Zero points are only supported for 4-bit (asymmetric uint4).
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """
    Compute a Mixture of Experts (MoE) layer using two sets of Marlin-quantized
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights (gate/up projection).
    - w2 (torch.Tensor): The second set of expert weights (down projection).
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax).
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of the top-k elements.
    - global_num_experts (int): Total number of experts; -1 means "use the
      number of experts in w1".
    - expert_map (Optional[torch.Tensor]): Global-to-local expert id mapping
      used under expert parallelism.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
      permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
      permutation.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - workspace (Optional[torch.Tensor]): Optional preallocated scratch buffer
      for the Marlin kernels; allocated on the fly when None.
    - num_bits (int): The number of bits in expert weights quantization.
    - is_k_full (bool): Whether the kernel sees the full K dimension
      (relevant with act_order under tensor parallelism).
    - inplace (bool): Write the result into hidden_states instead of a freshly
      allocated tensor.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Delay the import to avoid circular dependency
    from sglang.srt.layers.moe.fused_moe_triton import (
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    # Check constraints.
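    # Shape conventions implied by the asserts below: Marlin stores weights
    # pre-packed into int32 tiles, so the hidden size K must satisfy
    # K == w1.shape[1] * 16 and K == w2.shape[2] // (num_bits // 2), and the
    # intermediate size N is recovered later as w2.shape[1] * 16.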
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)
    block_size_m = config["BLOCK_SIZE_M"]

    if global_num_experts == -1:
        global_num_experts = E

    # Group token/expert pairs into expert-aligned blocks of block_size_m.
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, block_size_m, global_num_experts
    )

    if workspace is None:
        max_workspace_size = (max(2 * N, K) // 64) * (
            sorted_token_ids.size(0) // block_size_m
        )
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(
            max_workspace_size, dtype=torch.int, device=device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Caches 1 and 3 are never live at the same time, so they share a single
    # allocation sized for the larger of the two.
    intermediate_cache13 = torch.empty(
        (M * topk_ids.shape[1] * max(2 * N, K),),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
    intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    # First grouped GEMM: hidden_states @ w1 -> (M * topk, 2N).
    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        hidden_states,
        intermediate_cache1,
        w1,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type1.id,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    # Fused activation: SiLU on the gate half, multiplied by the up half.
    silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)

    if expert_map is not None:
        intermediate_cache3.zero_()

    # Second grouped GEMM: activations @ w2 -> (M * topk, K), with the top-k
    # routing weights fused in (mul_topk_weights=True).
    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type2.id,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

    output = hidden_states if inplace else torch.empty_like(hidden_states)
    # Reduce over the top-k experts for each token.
    return torch.sum(intermediate_cache3, dim=1, out=output)
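

# A minimal usage sketch, not part of the kernel API: it shows the expected
# call pattern, assuming the expert weights have already been repacked into
# Marlin format by a quantization loader. `gate_weight` and the Marlin-packed
# tensors are hypothetical placeholders supplied by the calling layer, and
# the int32 cast on topk_ids is an assumption about what the alignment kernel
# expects.
def _example_marlin_moe_forward(
    hidden_states: torch.Tensor,  # (num_tokens, K), fp16 or bf16
    gate_weight: torch.Tensor,  # (num_experts, K) router projection
    w1: torch.Tensor,  # Marlin-packed gate/up expert weights
    w2: torch.Tensor,  # Marlin-packed down expert weights
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk: int = 2,
) -> torch.Tensor:
    # Route each token: softmax over experts, then keep the k best.
    gating_output = hidden_states @ gate_weight.t()  # (num_tokens, num_experts)
    topk_weights, topk_ids = torch.topk(
        torch.softmax(gating_output, dim=-1, dtype=torch.float32), topk, dim=-1
    )
    return fused_marlin_moe(
        hidden_states,
        w1,
        w2,
        w1_scale,
        w2_scale,
        gating_output,
        topk_weights,
        topk_ids.to(torch.int32),
        num_bits=8,
    )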


def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
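

# fused_marlin_moe_fake only reports output metadata (same shape/dtype as
# hidden_states), so FakeTensor tracing and torch.compile can pass through
# the op without launching the CUDA kernels. A sketch of how the pair could
# be wired together with the stock torch.library API (PyTorch >= 2.4); the
# project may register it through its own helper instead, and mutates_args
# would need to account for inplace=True:
#
#   op = torch.library.custom_op(
#       "sglang::fused_marlin_moe", fused_marlin_moe, mutates_args=()
#   )
#   op.register_fake(fused_marlin_moe_fake)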