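"""Correctness test and bandwidth benchmark for the sglang HF3FS client.

`test()` writes 16 MiB pages of random data at random page-aligned offsets in
a 1 TiB file through Hf3fsClient and verifies that reads return the same data.
`bench()` then measures aggregate multi-threaded batch_read / batch_write
bandwidth and prints summary statistics.
"""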
import concurrent.futures
import logging
import random
import time
from typing import List

import torch
from tqdm import tqdm

from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient


def print_stats(x: List[float]):
    """Print mean, min, max, and percentile statistics for a list of samples."""
    x = sorted(x)
    lenx = len(x)
    print(
        f"mean = {sum(x)/len(x):.2f}, "
        f"min = {min(x):.2f}, "
        f"p25 = {x[int(lenx*0.25)]:.2f}, "
        f"p50 = {x[int(lenx*0.5)]:.2f}, "
        f"p75 = {x[int(lenx*0.75)]:.2f}, "
        f"max = {max(x):.2f}"
    )


def test():
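    """Write random pages, read them back, and check the contents match."""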
    # /path/to/hf3fs
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 32
    file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)

    print("test batch_read / batch_write")
    num_pages = 128
    dtype = torch.bfloat16
    numel = bytes_per_page // dtype.itemsize
    # Pick num_pages random page-aligned byte offsets within the file.
    offsets = list(range(file_size // bytes_per_page))
    random.shuffle(offsets)
    offsets = offsets[:num_pages]
    offsets = [i * bytes_per_page for i in offsets]
    tensor_writes = [
        torch.randn(numel, dtype=dtype)
        for _ in tqdm(range(num_pages), desc="prepare tensor")
    ]
    for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
        results = file_ops.batch_write(
            offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
        )
        assert all([result == numel * dtype.itemsize for result in results])
    tensor_reads = [
        torch.empty(numel, dtype=dtype)
        for _ in tqdm(range(num_pages), desc="prepare tensor")
    ]
    for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
        results = file_ops.batch_read(
            offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
        )
        assert all([result == numel * dtype.itemsize for result in results])
    # Every page read back must match the page that was written.
    assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])

    file_ops.close()
    print("test done")


def bench():
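    """Measure aggregate batch_write / batch_read bandwidth.

    Spawns numjobs worker threads, each with its own Hf3fsClient, and has every
    thread move `entries` pages per iteration. Bandwidth is the total bytes
    moved divided by wall-clock time, with the first `warmup` iterations
    discarded; print_stats reports the distribution over the remaining runs.
    """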
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 8
    numjobs = 16

    dtype = torch.bfloat16
    numel = bytes_per_page // dtype.itemsize

    # One independent client per worker thread.
    file_ops = [
        Hf3fsClient(file_path, file_size, bytes_per_page, entries)
        for _ in range(numjobs)
    ]

    num_page = entries

    # Each job reuses a single write buffer and a single read buffer,
    # repeated num_page times per batch.
    offsets = list(range(file_size // bytes_per_page))
    tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
    tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
    random.shuffle(offsets)

    warmup = 50
    iteration = 100

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)

    w_bw = []
    # Bytes moved per iteration across all jobs, in GiB.
    w_size = num_page * numjobs * bytes_per_page / (1 << 30)
    for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
        # Give every job its own disjoint slice of page-aligned offsets.
        _offsets = [
            [
                offset * bytes_per_page
                for offset in offsets[
                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
                ]
            ]
            for j in range(numjobs)
        ]
        tik = time.perf_counter()
        futures = [
            executor.submit(file_ops[j].batch_write, offset, tensors_write)
            for j, offset in enumerate(_offsets)
        ]
        results = [future.result() for future in futures]
        tok = time.perf_counter()
        if i < warmup:
            continue
        w_bw.append(w_size / (tok - tik))
        results = [
            _result == bytes_per_page for result in results for _result in result
        ]
        assert all(results)
    print_stats(w_bw)

    r_bw = []
    r_size = w_size
    for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
        _offsets = [
            [
                offset * bytes_per_page
                for offset in offsets[
                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
                ]
            ]
            for j in range(numjobs)
        ]
        tik = time.perf_counter()
        futures = [
            executor.submit(file_ops[j].batch_read, offset, tensors_read)
            for j, offset in enumerate(_offsets)
        ]
        results = [future.result() for future in futures]
        tok = time.perf_counter()
        if i < warmup:
            continue
        r_bw.append(r_size / (tok - tik))
        results = [
            _result == bytes_per_page for result in results for _result in result
        ]
        assert all(results)
    print_stats(r_bw)

    executor.shutdown(wait=True)
    for _file_ops in file_ops:
        _file_ops.close()
    print("bench done")


def main():
    logging.basicConfig(level=logging.INFO)
    test()
    bench()


if __name__ == "__main__":
    main()
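# Example invocation (assumes "/data/bench.bin" lives on an hf3fs mount, per the
# "/path/to/hf3fs" hint in test(), and that up to 1 TiB of space is available):
#
#     python3 path/to/this_script.py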