sglang0.4.5.post1/python/sglang/srt/disaggregation/conn.py

82 lines
1.9 KiB
Python

from __future__ import annotations
import logging
from enum import Enum
from typing import Optional
import numpy as np
import numpy.typing as npt
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class KVArgs:
    """Plain attribute container passed to ``KVManager``.

    Only annotations are declared; no defaults are assigned.  An instance
    is created with ``KVArgs()`` and each attribute must be set by the
    caller before it is read, otherwise access raises ``AttributeError``.
    """

    # NOTE(review): the per-field meanings below are inferred from the
    # attribute names only -- confirm against the code that fills them in.
    engine_rank: int  # presumably the rank of the owning engine
    kv_data_ptrs: list[int]  # presumably base addresses of the KV buffers
    kv_data_lens: list[int]  # presumably total length of each KV buffer
    kv_item_lens: list[int]  # presumably length of one item per KV buffer
    aux_data_ptrs: list[int]  # presumably base addresses of the aux buffers
    aux_data_lens: list[int]  # presumably total length of each aux buffer
    aux_item_lens: list[int]  # presumably length of one item per aux buffer
    ib_device: str  # presumably an InfiniBand device name -- TODO confirm
class KVManager:
    """Placeholder manager: accepts a ``KVArgs`` and does nothing with it."""

    def __init__(self, args: KVArgs):
        """Accept *args*; this stub neither stores nor inspects it."""
        ...
class KVPoll:
    """Integer status codes returned by the ``poll()`` methods in this file.

    The members are plain class-level ``int`` values (this is not an
    ``enum.Enum``), so they compare and order like ordinary integers.
    """

    Failed = 0
    Bootstrapping = 1
    # Reported by the fake sender/receiver before send()/init() is called.
    WaitingForInput = 2
    Transferring = 3
    # Reported by the fake sender/receiver after send()/init() is called.
    Success = 4
class KVSender:
    """Fake sender whose handshake and transfer both finish instantly.

    No data is moved: the only state is ``has_sent``, which ``send()``
    sets and ``poll()`` reads to choose the status it reports.
    """

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        """Create a sender; all three arguments are accepted but unused."""
        # Set to True by send(); poll() reports Success from then on.
        self.has_sent = False

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """No-op in this fake implementation; the arguments are ignored."""
        ...

    def send(self, kv_indices: npt.NDArray[np.int32]):
        """Mark the transfer as sent; *kv_indices* is not inspected."""
        self.has_sent = True

    def poll(self) -> KVPoll:
        """Return the current status as one of the ``KVPoll`` int constants.

        Returns:
            ``KVPoll.WaitingForInput`` until ``send()`` has been called,
            ``KVPoll.Success`` afterwards.
        """
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        return KVPoll.Success

    def failure_exception(self):
        """Raise the placeholder failure for this fake sender.

        Raises:
            Exception: always.
        """
        raise Exception("Fake KVSender Exception")
class KVReceiver:
    """Fake receiver whose handshake and transfer both finish instantly.

    No data is moved: the only state is ``has_init``, which ``init()``
    sets and ``poll()`` reads to choose the status it reports.
    """

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        """Create a receiver; all three arguments are accepted but unused."""
        # Set to True by init(); poll() reports Success from then on.
        self.has_init = False

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Mark the receiver as initialised; the arguments are not inspected."""
        self.has_init = True

    def poll(self) -> KVPoll:
        """Return the current status as one of the ``KVPoll`` int constants.

        Returns:
            ``KVPoll.WaitingForInput`` until ``init()`` has been called,
            ``KVPoll.Success`` afterwards.
        """
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        return KVPoll.Success

    def failure_exception(self):
        """Raise the placeholder failure for this fake receiver.

        Raises:
            Exception: always.
        """
        raise Exception("Fake KVReceiver Exception")
class KVBootstrapServer:
    """Placeholder bootstrap server; both methods are empty stubs."""

    def __init__(self, port: int):
        """Accept *port*; nothing is stored and nothing is started."""
        ...

    def poll(self) -> KVPoll:
        """Stub with no body: currently returns ``None`` despite the annotation."""
        ...