Patch for python/sglang/srt/managers/scheduler.py: switch profiling to rpd.

- Imports rpdTracerControl at module level and calls
  rpdTracerControl.skipCreate() right after the import.
- Creates self.rpd = rpdTracerControl() inside class Scheduler.
- start_profile(): comments out self.profiler.start(); on tp_rank 0 and 1 it
  calls rpd.start() and pushes the range "rpd profile range".
- stop_profile(): comments out the torch profiler stop/export_chrome_trace;
  on tp_rank 0 and 1 it pops the range, then calls rpd.stop() and rpd.flush().
- Both methods still raise RuntimeError when self.profiler is None.

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..9021c01 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
 
 logger = logging.getLogger(__name__)
 
@@ -245,6 +247,7 @@ class Scheduler:
                 ],
                 with_stack=True,
             )
+        self.rpd = rpdTracerControl()
 
     @torch.inference_mode()
     def event_loop(self):
@@ -1027,13 +1030,22 @@ class Scheduler:
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.start()
+        #self.profiler.start() #block pytorch profiler for rpd profiler enabling
+        if self.tp_rank == 0 or self.tp_rank == 1:
+            self.rpd.start()
+            self.rpd.rangePush("", "rpd profile range", "")
+            logger.info("rpd is enabled")
 
     def stop_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
-        self.profiler.stop()
-        self.profiler.export_chrome_trace(
-            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
-        )
+        #self.profiler.stop()
+        #self.profiler.export_chrome_trace(
+        #    self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        #)
+        if self.tp_rank ==0 or self.tp_rank ==1:
+            self.rpd.rangePop()
+            self.rpd.stop()
+            self.rpd.flush()
+            logger.info("rpd is done")
         logger.info("Profiler is done")