sglang.0.4.8.post1/sglang/test/srt/test_schedule_policy.py

66 lines
2.2 KiB
Python

import unittest
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_policy import (
CacheAgnosticPolicy,
CacheAwarePolicy,
SchedulePolicy,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.test.test_utils import CustomTestCase
class TestSchedulePolicy(CustomTestCase):
    """Unit tests for ``SchedulePolicy`` construction and priority calculation."""

    def setUp(self):
        # A fresh, enabled radix cache for every test.  Arguments are passed by
        # keyword: the constructor also accepts ``page_size`` (see
        # ``test_init_with_disabled_cache``), so a bare positional ``False``
        # could bind to ``page_size`` instead of ``disable`` depending on the
        # constructor's parameter order.
        self.tree_cache = RadixCache(None, None, page_size=1, disable=False)

    def test_init_with_cache_aware_policy(self):
        """The "lpm" policy string maps to ``CacheAwarePolicy.LPM``."""
        policy = SchedulePolicy(
            policy="lpm", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAwarePolicy.LPM)

    def test_init_with_cache_agnostic_policy(self):
        """The "fcfs" policy string maps to ``CacheAgnosticPolicy.FCFS``."""
        policy = SchedulePolicy(
            policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_init_with_unknown_policy(self):
        """An unrecognized policy string is rejected with ``ValueError``."""
        with self.assertRaises(ValueError):
            SchedulePolicy(
                policy="invalid",
                tree_cache=self.tree_cache,
                enable_hierarchical_cache=True,
            )

    def test_init_with_disabled_cache(self):
        """Requesting "lpm" with a disabled cache falls back to FCFS."""
        disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1)
        policy = SchedulePolicy(
            policy="lpm", tree_cache=disabled_tree_cache, enable_hierarchical_cache=True
        )
        self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS)

    def test_calc_priority_fcfs(self):
        """FCFS ``calc_priority`` must leave the waiting queue in arrival order."""
        # rids (1, 3, 2) and input lengths (2, 3, 1) are both out of order, so
        # any re-sort by rid or by input length would change the result.
        waiting_queue = [
            Req(1, "a b", [1, 2], SamplingParams()),
            Req(3, "a b c", [1, 2, 3], SamplingParams()),
            Req(2, "a", [1], SamplingParams()),
        ]
        # Reuse the per-test cache built in setUp rather than a second instance.
        policy = SchedulePolicy(
            policy="fcfs", tree_cache=self.tree_cache, enable_hierarchical_cache=True
        )
        policy.calc_priority(waiting_queue)
        # Compare the whole order at once so a failure shows the full sequence.
        self.assertEqual([req.rid for req in waiting_queue], [1, 3, 2])
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's command-line
    # runner, which runs the TestCase classes defined in this module.
    unittest.main(module="__main__")