# Source: sglang_v0.5.2/sglang/benchmark/hf3fs/bench_storage.py
# (259 lines, 7.1 KiB, Python)

import json
import logging
import os
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
def print_stats(x: List[float]):
    """Print mean, min, quartiles and max of the given samples on one line.

    Args:
        x: Numeric samples in any order. The caller's list is not mutated;
            a sorted copy is used internally.

    The p25/p50/p75 values are picked by direct index into the sorted
    samples (``sorted_x[int(n * q)]``) with no interpolation. All numbers are
    formatted with two decimals. An empty input prints ``no samples`` instead
    of raising ``ZeroDivisionError``.
    """
    if not x:
        # Nothing to summarise; avoid dividing by zero / indexing an empty list.
        print("no samples")
        return

    ordered = sorted(x)
    count = len(ordered)
    print(
        f"mean = {sum(ordered)/count:.2f}, "
        f"min = {min(ordered):.2f}, "
        f"p25 = {ordered[int(count*0.25)]:.2f}, "
        f"p50 = {ordered[int(count*0.5)]:.2f}, "
        f"p75 = {ordered[int(count*0.75)]:.2f}, "
        f"max = {max(ordered):.2f}"
    )
def test():
    """Run a functional check of the ``HiCacheHF3FS`` single-key and batch APIs.

    Writes a JSON config (``file_path_prefix``, ``file_size``, ``numjobs``,
    ``entries``) to the path named by the ``HiCacheHF3FS.default_env_var``
    environment variable, builds the store with
    ``HiCacheHF3FS.from_env_config`` and then asserts that:

    * ``set`` succeeds for the first ``file_size // bytes_per_page`` keys and
      fails for the remaining ones, and ``get`` returns ``None`` for the
      rejected keys;
    * every key in ``range(hicache_hf3fs.num_pages)`` exists and reads back
      within ``atol=1e-3`` of what was written;
    * after deleting ``key_7``, setting ``key_new`` succeeds and reads back;
    * after ``clear`` the number of free pages equals the number of pages;
    * ``batch_set`` with ten pages reports failure, the overflow keys are
      absent, and ``batch_get`` on the first ``num_pages`` keys returns the
      written data.

    The store is closed and its backing file removed even when a check fails.

    Raises:
        RuntimeError: If the config file cannot be written.
        AssertionError: If any check fails.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path_prefix = "/data/test"
    file_size = 128 << 20
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 2
    dtype = store_dtype

    config_path = os.getenv(HiCacheHF3FS.default_env_var)
    assert (
        config_path
    ), f"environment variable {HiCacheHF3FS.default_env_var} must be set"
    try:
        with open(config_path, "w") as f:
            json.dump(
                {
                    "file_path_prefix": file_path_prefix,
                    "file_size": file_size,
                    "numjobs": numjobs,
                    "entries": entries,
                },
                f,
            )
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to dump config to {config_path}: {e}") from e

    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)
    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        # Number of pages the backing file can hold; inserts beyond this
        # are expected to be rejected.
        capacity = file_size // bytes_per_page

        # --- single-key API ---
        num_pages = 10
        tensors = {}
        for i in range(num_pages):
            k = f"key_{i}"
            v = torch.randn((numel,)).to(dtype=dtype)
            ok = hicache_hf3fs.set(k, v)
            if i < capacity:
                assert ok, f"Failed to insert {k}"
            else:
                assert not ok
            tensors[k] = v
        # Keys that did not fit must not be readable.
        for i in range(capacity, num_pages):
            assert hicache_hf3fs.get(f"key_{i}") is None

        for i in range(hicache_hf3fs.num_pages):
            k = f"key_{i}"
            assert hicache_hf3fs.exists(k)
            out = hicache_hf3fs.get(k)
            assert out is not None
            v = tensors[k]
            assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"

        assert not hicache_hf3fs.exists("not_exists")

        # Deleting one stored key must make room for a new one.
        hicache_hf3fs.delete("key_7")
        v2 = torch.randn((numel,)).to(dtype=dtype)
        assert hicache_hf3fs.set("key_new", v2)
        assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)

        hicache_hf3fs.clear()
        assert (
            len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
            == hicache_hf3fs.metadata_client.rank_metadata.num_pages
        )

        # --- batch API ---
        num_pages = 10
        keys = [f"key_{i}" for i in range(num_pages)]
        values = [torch.randn((numel,)).to(dtype=dtype) for _ in range(num_pages)]

        # More pages than fit, so the batch as a whole must report failure.
        ok = hicache_hf3fs.batch_set(keys, values)
        assert not ok
        for i in range(capacity, num_pages):
            assert hicache_hf3fs.get(f"key_{i}") is None

        stored_keys = keys[: hicache_hf3fs.num_pages]
        stored_values = values[: hicache_hf3fs.num_pages]
        results = hicache_hf3fs.batch_get(stored_keys)
        for result, key, value in zip(results, stored_keys, stored_values):
            assert torch.allclose(
                value, result, atol=1e-3
            ), f"Tensor mismatch for {key}"
    finally:
        # Always release the store and remove its backing file, even when
        # one of the checks above failed.
        hicache_hf3fs.close()
        os.remove(hicache_hf3fs.file_path)

    print("All test cases passed.")
def bench():
    """Measure batched write and read bandwidth of ``HiCacheHF3FS``.

    Builds a store on ``/data/test.bin`` with a local metadata client, then
    runs ``warmup + iteration`` rounds of ``batch_set`` followed by the same
    number of rounds of ``batch_get``. Each round moves ``num_page`` pages of
    ``bytes_per_page`` bytes. Per-round bandwidth is computed as
    ``size / (1 << 30) / elapsed_seconds`` and summarised with
    ``print_stats``.

    Warm-up rounds are neither recorded nor checked for success; measured
    rounds assert that the operation succeeded. The store is closed even when
    a round fails.

    Raises:
        AssertionError: If a measured write fails or a measured read returns
            ``None`` for any key.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        rank=0,
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
        metadata_client=Hf3fsLocalMetadataClient(),
    )

    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        # The same payload tensors are reused for every write round.
        values = [
            torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))
        ]

        warmup = 50
        iteration = 100

        # --- write ---
        w_bw = []
        w_size = num_page * bytes_per_page / (1 << 30)
        for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
            # Each round uses a fresh, disjoint range of integer keys.
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            tik = time.perf_counter()
            ok = hicache_hf3fs.batch_set(keys, values)
            tok = time.perf_counter()
            if i < warmup:
                continue
            w_bw.append(w_size / (tok - tik))
            assert ok
        print_stats(w_bw)

        # --- read ---
        r_bw = []
        r_size = num_page * bytes_per_page / (1 << 30)
        for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
            # Key sampling happens outside the timed region.
            keys = random.sample(
                list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
                num_page,
            )
            tik = time.perf_counter()
            results = hicache_hf3fs.batch_get(keys)
            tok = time.perf_counter()
            if i < warmup:
                continue
            r_bw.append(r_size / (tok - tik))
            assert all(r is not None for r in results)
        print_stats(r_bw)
    finally:
        # Release the store even if a round failed.
        hicache_hf3fs.close()
def allclose():
    """Check that pages read back from ``HiCacheHF3FS`` match what was written.

    Builds a store on ``/data/test.bin`` with a local metadata client, writes
    ``iteration`` batches of ``num_page`` pages (the same ``values`` list is
    reused for every batch, keyed by consecutive integers), then reads
    ``iteration`` batches of randomly sampled keys and compares each result
    with the tensor originally written under that key (``atol=1e-3``).

    Each batch is verified as soon as it is read, so only one batch of
    results is kept alive at a time instead of every result of every round.
    The store is closed even when a check fails.

    Raises:
        AssertionError: If a write fails, a read returns ``None``, or a
            tensor read back differs from the one written.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        rank=0,
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
        metadata_client=Hf3fsLocalMetadataClient(),
    )

    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        values = [
            torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))
        ]

        iteration = 100

        for i in tqdm(range(iteration), desc="Writing pages"):
            # Key j stores values[j % num_page]; the read check relies on this.
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            ok = hicache_hf3fs.batch_set(keys, values)
            assert ok

        for _ in tqdm(range(iteration), desc="Reading and verifying pages"):
            keys = random.sample(
                list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
                num_page,
            )
            results = hicache_hf3fs.batch_get(keys)
            assert all(r is not None for r in results)
            # Verify immediately rather than accumulating every result first.
            for key, result in zip(keys, results):
                assert torch.allclose(
                    values[int(key) % num_page], result, atol=1e-3
                )
    finally:
        # Release the store even if a check failed.
        hicache_hf3fs.close()
def main():
    """Enable INFO-level logging, then run the functional test, the
    bandwidth benchmark and the read-back consistency check, in that order."""
    logging.basicConfig(level=logging.INFO)
    for stage in (test, bench, allclose):
        stage()


if __name__ == "__main__":
    main()