"""Fake (no-op) KV transfer backend.

Every handshake and transfer is assumed to complete instantly: senders and
receivers report ``KVPoll.WaitingForInput`` until their send/init call has
been made, and ``KVPoll.Success`` afterwards. No data is actually moved.
"""

from __future__ import annotations

import logging
from enum import Enum
from typing import Optional

import numpy as np
import numpy.typing as npt

logger = logging.getLogger(__name__)


class KVArgs:
    """Configuration bundle handed to :class:`KVManager`.

    Only annotations are declared; there is no ``__init__``, so an instance
    starts with none of these attributes set. Callers presumably construct it
    empty and assign each field afterwards -- confirm against call sites.
    """

    engine_rank: int
    # NOTE(review): the three kv_* lists look like parallel per-buffer
    # descriptions (base pointer, total length, per-item length) -- confirm
    # units (bytes vs. elements) with the real backend.
    kv_data_ptrs: list[int]
    kv_data_lens: list[int]
    kv_item_lens: list[int]
    # Same layout as the kv_* lists, presumably for auxiliary buffers.
    aux_data_ptrs: list[int]
    aux_data_lens: list[int]
    aux_item_lens: list[int]
    # Presumably the name of the InfiniBand device to use; unused by the fake.
    ib_device: str


class KVManager:
    """Placeholder manager; the fake backend ignores ``args`` entirely."""

    def __init__(self, args: KVArgs):
        ...


class KVPoll:
    """Status codes returned by ``poll()``.

    These are plain ``int`` class attributes (not an ``Enum``), so the
    ``-> KVPoll`` annotations below really mean "one of these ints".
    """

    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class KVSender:
    """Fake sender: reports success as soon as :meth:`send` has been called."""

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        # All arguments are accepted for interface compatibility and ignored.
        self.has_sent = False

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Accept the transfer size; a no-op in the fake backend."""
        ...

    def send(self, kv_indices: npt.NDArray[np.int32]):
        """Pretend to send ``kv_indices``; only records that a send happened."""
        self.has_sent = True

    def poll(self) -> KVPoll:
        """Return ``WaitingForInput`` before :meth:`send`, ``Success`` after."""
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        return KVPoll.Success

    def failure_exception(self):
        """Raise the fake backend's sender failure exception.

        Raises:
            Exception: always, with message ``"Fake KVSender Exception"``.
        """
        raise Exception("Fake KVSender Exception")


class KVReceiver:
    """Fake receiver: reports success as soon as :meth:`init` has been called."""

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        # All arguments are accepted for interface compatibility and ignored.
        self.has_init = False

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Pretend to post receive buffers; only records that init happened."""
        self.has_init = True

    def poll(self) -> KVPoll:
        """Return ``WaitingForInput`` before :meth:`init`, ``Success`` after."""
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        return KVPoll.Success

    def failure_exception(self):
        """Raise the fake backend's receiver failure exception.

        Raises:
            Exception: always, with message ``"Fake KVReceiver Exception"``.
        """
        raise Exception("Fake KVReceiver Exception")


class KVBootstrapServer:
    """Placeholder bootstrap server; the fake backend does nothing here."""

    def __init__(self, port: int):
        ...

    def poll(self) -> KVPoll:
        """Stub: the body is ``...``, so this currently returns ``None``."""
        ...