import datetime
import os
import sys

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine


def run():
    """
    Offline batch inference with SGLang's VerlEngine: torchrun launches one
    process per GPU, and the 8 ranks form a 4-way tensor-parallel x 2-way
    data-parallel layout.

    Example command:
    ```
    torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
    ```
    """

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
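
    # All ranks print concurrently, so tag every line with a timestamp and the
    # rank to keep the interleaved output readable.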
    def _log(text):
        t = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"[{t}] [rank={rank}] {text}")

    _log(
        f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
    )
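
    # 4-way tensor parallelism x 2-way data parallelism; torchrun must launch
    # exactly tp_size * dp_size = 8 processes (--nproc_per_node=8).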
    tp_size = 4
    dp_size = 2
    assert world_size == tp_size * dp_size
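
    # The mesh is created on CPU: it describes the ("tp", "dp", "pp") process
    # layout that VerlEngine coordinates over. Pipeline parallelism is unused
    # here, so the "pp" dimension has size 1.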
    device_mesh_kwargs = dict(
        mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
    )
    device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
    _log(f"{device_mesh_cpu=}")

    tp_rank = device_mesh_cpu.get_local_rank("tp")
    dp_rank = device_mesh_cpu.get_local_rank("dp")
    _log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")
    model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
    # model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9  # test large models
    # model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8
    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]
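
    # Each VerlEngine instance spans one TP group. base_gpu_id/gpu_id_step
    # interleave the two DP replicas across the 8 local GPUs: replica 0 takes
    # GPUs 0, 2, 4, 6 and replica 1 takes GPUs 1, 3, 5, 7.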
    fragment = VerlEngine(
        model_path=model_name,
        mem_fraction_static=mem_fraction_static,
        device_mesh_cpu=device_mesh_cpu["tp"],
        base_gpu_id=dp_rank,
        gpu_id_step=dp_size,
        port=30000,
        # for DeepSeek-V2-Lite + DP attention, give each replica its own port:
        # enable_dp_attention=True, port=30000 + dp_rank * 100,
    )
    _log(f"{fragment=}")
    prompt_all = [
        ["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
        ["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
    ]
    prompt = prompt_all[dp_rank]
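
    # temperature=0.0 selects greedy decoding, so the completions are
    # deterministic; generate() on a batch of prompts returns a list of dicts
    # whose "text" field holds each completion.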
    output = fragment.generate(
        prompt=prompt,
        sampling_params=dict(max_new_tokens=16, temperature=0.0),
    )
    _log(f"{prompt=} {output=}")
    fragment.shutdown()
    _log("End script")


if __name__ == "__main__":
    run()