# sglang0.4.5.post1/examples/runtime/engine/offline_batch_inference_torchrun.py
# (82 lines, 2.2 KiB, Python)

import datetime
import os
import sys
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.entrypoints.verl_engine import VerlEngine
def run():
    """Run a small offline batch-inference demo with ``VerlEngine`` under torchrun.

    The world of ``WORLD_SIZE`` processes is arranged as a ``tp_size`` x
    ``dp_size`` x 1 device mesh. Each process builds a ``VerlEngine`` over its
    "tp" sub-mesh, and each data-parallel index generates completions for its
    own list of prompts. The engine is always shut down, even if generation
    fails.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment
    (set by torchrun).

    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```

    Raises:
        KeyError: if ``LOCAL_RANK``, ``RANK`` or ``WORLD_SIZE`` is not set.
        ValueError: if ``WORLD_SIZE`` differs from ``tp_size * dp_size``.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    def _log(text):
        """Print ``text`` prefixed with the wall-clock time and this process's rank."""
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )

    tp_size = 4
    dp_size = 2
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if world_size != tp_size * dp_size:
        raise ValueError(
            f"WORLD_SIZE={world_size} must equal tp_size * dp_size "
            f"= {tp_size} * {dp_size} = {tp_size * dp_size}"
        )

    # "tp" is the outermost mesh dimension. Assuming torch lays ranks out
    # row-major over mesh_shape, ranks that share a dp index form one TP group
    # (with the defaults: dp 0 -> ranks 0,2,4,6; dp 1 -> ranks 1,3,5,7).
    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")

    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8

    # NOTE(review): presumably removed so the engine's own distributed setup
    # does not try to reuse torchrun's agent store. Confirm with sglang docs.
    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        os.environ.pop(k, None)

    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        # Presumably selects GPUs dp_rank, dp_rank + dp_size, ... for this TP
        # group, matching the rank layout above. Verify against VerlEngine.
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP Attention
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    try:
        _log(f"{fragment=}")

        # One prompt list per data-parallel index; indexed by dp_rank below.
        prompt_all = [
            ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
            ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
        ]
        prompt = prompt_all[dp_rank]

        output = fragment.generate(
            prompt=prompt,
            sampling_params=dict(max_new_tokens=16, temperature=0.0),
        )
        _log(f"{prompt=} {output=}")
    finally:
        # Release the engine even when generation raises, so a failing rank
        # does not leave the engine's resources behind.
        fragment.shutdown()
    _log("End script")
if __name__ == "__main__":
    # Script entry point: torchrun launches one process per rank, and each
    # process executes run() independently.
    run()