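# Standalone micro-benchmark for sglang's HiCacheController: it times the page
# backup path (StorageOperation, expected to write KV pages from the host pool
# to the storage backend) and the page transfer path used during prefetch
# (PrefetchOperation, expected to read pages back), for the selected host
# memory layout ("page_first" or "layer_first").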
import threading
import time

import torch
from tqdm import tqdm

from sglang.srt.distributed import (
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
    HiCacheController,
    PrefetchOperation,
    StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

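# Bring up a minimal single-rank distributed environment on the gloo backend;
# the cache controller below is handed the CPU process group of the world group.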
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)

initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

group = get_world_group().cpu_group

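# KV cache geometry and HiCache settings. With bfloat16, 64 layers, 8 KV heads
# of dim 128 and 524288 tokens, the device pool covers
# 524288 * 64 * 2 * 8 * 128 * 2 bytes = 128 GiB of K+V data; the host pool is
# sized from hicache_ratio (with hicache_size = 0 the ratio presumably takes
# effect).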
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256

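# Benchmark workload: op_num operations of op_size tokens each, i.e. 16 x 1024
# tokens, or 16 pages per operation at page_size = 64.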
op_size = 1024
op_num = 16

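# Device-side KV pool for an MHA model with the geometry configured above.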
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)

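# Token-granularity allocator that manages slots in the device KV pool.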
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)

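# Host (CPU) side of the KV pool, using the selected memory layout.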
kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)

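# The cache controller moves KV pages between the device pool, the host pool,
# and the external storage backend ("hf3fs" here), using the configured write
# policy and IO backend.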
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)

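# Build the backup (write) operations: each one covers op_size consecutive
# token indices and carries one hash string per page.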
operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

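# Time the backup path: "page_first" uses the zero-copy variant, "layer_first"
# the generic one.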
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")

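# Build the prefetch operations over the same index ranges; their per-page
# hash values are derived below from each operation's last_hash.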
operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]

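# Time the transfer (prefetch read) path with the same layout dispatch as above.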
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")