sglang_v0.5.2/sglang/benchmark/hf3fs/bench_zerocopy.py

import threading
import time
import torch
from tqdm import tqdm
from sglang.srt.distributed import (
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
    HiCacheController,
    PrefetchOperation,
    StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
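
# Single-rank distributed setup on the gloo backend; HiCacheController
# expects a CPU process group for its collective operations.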
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)
initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)
group = get_world_group().cpu_group
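
# Benchmark configuration. The KV cache geometry (layer_num, head_num,
# head_dim) is sized like a large MHA model; page_size is the granularity
# at which pages move between host memory and the hf3fs storage backend.
# Toggle hicache_mem_layout to compare the zero-copy "page_first" path
# against the generic "layer_first" path.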
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256
op_size = 1024
op_num = 16
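
# Device-side KV pool and the token allocator that fronts it.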
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)
kv_cache = token_to_kv_pool_allocator.get_kvcache()
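
# Host-side KV pool that mirrors the device pool; its capacity is
# controlled by hicache_ratio / hicache_size, and hicache_mem_layout
# selects the page_first or layer_first memory layout.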
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)
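
# HiCacheController orchestrates device<->host transfers and talks to the
# hf3fs storage backend exercised by this benchmark.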
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)
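
# Build op_num write (backup) operations of op_size tokens each; every
# page_size-token page gets a string hash that serves as its storage key.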
operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]
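
# Time the host -> storage backup path that matches the chosen layout.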
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"backup: {tok - tik:.6f} s")
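
# Build matching read (prefetch) operations; hash values are derived from
# last_hash below so they line up with the pages written above.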
operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]
for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]
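
# Time the storage -> host prefetch path that matches the chosen layout.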
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"prefetch: {tok - tik:.6f} s")