82 lines
1.9 KiB
Python
82 lines
1.9 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
# Module-scoped logger; its name is this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class KVArgs:
    """Plain attribute container describing one engine's transfer buffers.

    Only annotations are declared: there is no ``__init__`` and no defaults,
    so a fresh ``KVArgs()`` has none of these attributes set. Callers
    presumably assign every field after construction — TODO confirm against
    call sites.
    """

    # Rank of the engine these buffers belong to (which parallel group the
    # rank refers to is not visible here — confirm).
    engine_rank: int
    # Parallel lists for the KV buffers: presumably base addresses, total
    # lengths and per-item lengths, one entry per buffer (units and layout
    # are assumptions — verify with the code that fills them in).
    kv_data_ptrs: list[int]
    kv_data_lens: list[int]
    kv_item_lens: list[int]
    # The same address / length / item-length triple for auxiliary buffers.
    aux_data_ptrs: list[int]
    aux_data_lens: list[int]
    aux_item_lens: list[int]
    # Device name used for transfers; presumably an InfiniBand device
    # identifier, judging by the field name — confirm.
    ib_device: str
|
|
|
|
|
|
class KVManager:
    """Placeholder KV manager; constructing one has no effect in this file."""

    def __init__(self, args: KVArgs):
        """Accept a ``KVArgs`` instance without storing or inspecting it."""
|
|
|
|
|
|
class KVPoll(int, Enum):
    """Status codes returned by the ``poll()`` methods in this module.

    Mixing in ``int`` keeps every member a real ``int`` with the same value as
    before, so ``KVPoll.Success == 4``, hashing, ordering, ``min``/``max``,
    truthiness (``Failed`` is falsy) and numeric serialization are unchanged.
    Being an ``Enum`` makes the ``-> KVPoll`` return annotations accurate and
    gives members readable names in ``repr`` and logs.
    """

    # NOTE(review): the numbering presumably matters to callers (e.g. a
    # min-reduction surfacing Failed first); do not renumber without checking.
    Failed = 0
    Bootstrapping = 1
    # Handshake finished; waiting for the data to transfer.
    WaitingForInput = 2
    Transferring = 3
    # Transfer finished.
    Success = 4
|
|
|
|
|
|
class KVSender:
    """Fake KV sender that finishes every step instantly without moving data.

    Its only state is the ``has_sent`` flag, flipped by :meth:`send`;
    :meth:`poll` derives the reported status from that flag alone.
    """

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        # mgr, bootstrap_addr and bootstrap_room are accepted but neither
        # stored nor used by this fake implementation.
        self.has_sent = False

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Accept the transfer layout; this is a no-op for the fake sender."""

    def send(self, kv_indices: npt.NDArray[np.int32]):
        """Record that a send was requested; *kv_indices* itself is ignored."""
        self.has_sent = True

    def poll(self) -> KVPoll:
        """Report ``WaitingForInput`` before :meth:`send` and ``Success`` after.

        Both the handshake and the transfer are assumed to finish instantly.
        """
        return KVPoll.WaitingForInput if self.has_sent is False else KVPoll.Success

    def failure_exception(self):
        """Raise the fixed placeholder error used by the fake backend."""
        raise Exception("Fake KVSender Exception")
|
|
|
|
|
|
class KVReceiver:
    """Fake KV receiver that finishes every step instantly without moving data.

    Its only state is the ``has_init`` flag, set by :meth:`init`;
    :meth:`poll` derives the reported status from that flag alone.
    """

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        # mgr, bootstrap_addr and bootstrap_room are accepted but neither
        # stored nor used by this fake implementation.
        self.has_init = False

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Record that ``init`` was called; both arguments are ignored."""
        self.has_init = True

    def poll(self) -> KVPoll:
        """Report ``WaitingForInput`` before :meth:`init` and ``Success`` after.

        Both the handshake and the transfer are assumed to finish instantly.
        """
        if self.has_init is not False:
            return KVPoll.Success
        return KVPoll.WaitingForInput

    def failure_exception(self):
        """Raise the fixed placeholder error used by the fake backend."""
        raise Exception("Fake KVReceiver Exception")
|
|
|
|
|
|
class KVBootstrapServer:
    """Placeholder bootstrap server whose methods do nothing in this file."""

    def __init__(self, port: int):
        """Accept *port* without binding to it or storing it."""

    def poll(self) -> KVPoll:
        """Return ``None``.

        NOTE(review): annotated ``-> KVPoll`` but the stub body returns
        ``None``; callers presumably never poll the fake server — confirm.
        """
|