sglang_v0.5.2/sglang/benchmark/hf3fs/bench_client.py

163 lines
4.7 KiB
Python

import concurrent.futures
import logging
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
def print_stats(x: List[int]):
x = sorted(x)
lenx = len(x)
print(
f"mean = {sum(x)/len(x):.2f}, "
f"min = {min(x):.2f}, "
f"p25 = {x[int(lenx*0.25)]:.2f}, "
f"p50 = {x[int(lenx*0.5)]:.2f}, "
f"p75 = {x[int(lenx*0.75)]:.2f}, "
f"max = {max(x):.2f}"
)
def test():
# /path/to/hf3fs
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 32
file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)
print("test batch_read / batch_write")
num_pages = 128
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
offsets = list(range(file_size // bytes_per_page))
random.shuffle(offsets)
offsets = offsets[:num_pages]
offsets = [i * bytes_per_page for i in offsets]
tensor_writes = [
torch.randn(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
results = file_ops.batch_write(
offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
tensor_reads = [
torch.empty(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
results = file_ops.batch_read(
offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])
file_ops.close()
print("test done")
def bench():
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 8
numjobs = 16
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
file_ops = [
Hf3fsClient(file_path, file_size, bytes_per_page, entries)
for _ in range(numjobs)
]
num_page = entries
offsets = list(range(file_size // bytes_per_page))
tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
random.shuffle(offsets)
warmup = 50
iteration = 100
executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)
w_bw = []
w_size = num_page * numjobs * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_write, offset, tensors_write)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
w_bw.append(w_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(w_bw)
r_bw = []
r_size = w_size
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_read, offset, tensors_read)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
r_bw.append(r_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(r_bw)
executor.shutdown(wait=True)
for _file_ops in file_ops:
_file_ops.close()
print("bench done")
def main():
logging.basicConfig(level=logging.INFO)
test()
bench()
if __name__ == "__main__":
main()