import itertools
import unittest

import sgl_kernel
import torch
from utils import precision

from sglang.srt.layers.moe.topk import (
    biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase

torch.manual_seed(1234)


# grouped_topk is used by the DeepSeek-V2 model.
class TestGroupedTopK(CustomTestCase):
    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        torch.manual_seed(1234)
        # Scale gating_output by M; otherwise nearby bfloat16 values collapse
        # to the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )

        # Scatter the (id, weight) pairs into dense [M, E] maps so the
        # comparison is independent of top-k ordering.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)


# DeepSeek V2/V3/R1 use biased_grouped_topk.
class TestBiasedGroupedTopK(CustomTestCase):
    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        torch.manual_seed(1234)
        # Scale gating_output by M; otherwise nearby bfloat16 values collapse
        # to the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=dtype)

        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )

        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)
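
# Why each _run_single_test scales gating_output by M (illustrative sketch, not
# executed by the suite): bfloat16 keeps only 7 mantissa bits, so nearby float32
# logits collapse to the same bfloat16 value, producing ties that make the
# top-k selection ambiguous between the native and fused paths. Widening the
# value range avoids the ties.
#
#   x = torch.tensor([1.0001, 1.0002, 1.0003])
#   x.to(torch.bfloat16)  # all three map to 1.0 (the step near 1.0 is 2**-7)
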
class TestTopK(CustomTestCase):
    def _run_single_test(self, M, E, topk, renormalize, dtype):
        torch.manual_seed(1998)
        # Scale gating_output by M; otherwise nearby bfloat16 values collapse
        # to the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )

        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)


class TestCustomTopK(CustomTestCase):
    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        torch.manual_seed(16)
        # Scale gating_output by M; otherwise nearby bfloat16 values collapse
        # to the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_custom_f(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # fused version
        topk_weights, topk_ids = fused_custom_f(
            hidden_states, gating_output, topk, renormalize
        )

        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_custom_topk(self):
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            self._run_single_test(
                123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )


if __name__ == "__main__":
    unittest.main()
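
# Note on the comparison scheme used above: the native and fused paths may
# order tied experts differently, so the tests do not compare (ids, weights)
# pairs positionally. Instead they scatter the weights into dense [M, E] maps
# keyed by expert id and compare the maps. Toy example (E = 4, not executed):
#
#   ids = [2, 0], weights = [0.7, 0.3]  ->  dense = [0.3, 0.0, 0.7, 0.0]
#   ids = [0, 2], weights = [0.3, 0.7]  ->  dense = [0.3, 0.0, 0.7, 0.0]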