# File: sglang_v0.5.2/sglang/test/srt/cpu/test_binding.py
# (29 lines, 630 B, Python)

import re
import unittest

# Imported for its side effect: it presumably registers the custom ops that
# torch exposes under `torch.ops.sgl_kernel` (confirm against sgl_kernel).
import sgl_kernel  # noqa: F401
import torch

from sglang.test.test_utils import CustomTestCase

# Namespace handle for the sgl_kernel custom ops; must be taken after
# `import sgl_kernel` above.
kernel = torch.ops.sgl_kernel
class TestGemm(CustomTestCase):
    """Tests for ``kernel.init_cpu_threads_env`` thread-to-core binding.

    NOTE(review): the class name says GEMM, but this test exercises CPU
    thread binding. The name is kept unchanged so existing test selectors
    (e.g. ``TestGemm.test_binding``) keep working.
    """

    def test_binding(self):
        """Bind threads to cores ``start_id .. start_id + n_cpu - 1`` and verify
        that the reported ``OMP tid: <tid>, core <core>`` entries list exactly
        those cores, in that order.

        Skips (instead of failing) when any requested core is not usable by
        this process, e.g. on machines with too few CPUs or a restricted
        cpuset.
        """
        import os

        start_id = 1
        n_cpu = 6
        core_ids = range(start_id, start_id + n_cpu)

        # Cores the process is allowed to run on. Linux exposes the real
        # affinity mask; elsewhere, fall back to the logical CPU count.
        if hasattr(os, "sched_getaffinity"):
            usable = os.sched_getaffinity(0)
        else:
            usable = set(range(os.cpu_count() or 0))
        missing = [core for core in core_ids if core not in usable]
        if missing:
            self.skipTest(f"cores {missing} are not available to this process")

        expected_cores = [str(core) for core in core_ids]
        cpu_ids = ",".join(expected_cores)
        output = kernel.init_cpu_threads_env(cpu_ids)

        # `output` is parsed as text; each match yields the core id (as a
        # string) for one reported thread, in the order they appear.
        bindings = re.findall(r"OMP tid: \d+, core (\d+)", output)
        # Include the raw kernel output so failures are diagnosable.
        self.assertEqual(len(bindings), n_cpu, msg=output)
        self.assertEqual(bindings, expected_cores, msg=output)
if __name__ == "__main__":
    # Allow running this file directly (e.g. `python test_binding.py`);
    # unittest discovers and runs the TestCase classes defined above.
    unittest.main()