sglang_v0.5.2/flashinfer_0.3.1/flashinfer/profiler/__init__.py

102 lines
3.1 KiB
Python

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import csv
import json
from collections import namedtuple
from enum import Enum
from typing import Any, Dict, List, Tuple
import torch
from tg4perfetto import TraceGenerator
class EventType(Enum):
    """Kind of a profiler record, held in the two lowest bits of a tag.

    The integer values are compared against the ``event_type`` field that
    ``decode_tag`` extracts (``tag & 0x3``).
    """

    kBegin = 0  # export_to_perfetto_trace calls track.open(...) for this kind
    kEnd = 1  # export_to_perfetto_trace calls track.close(...) for this kind
    kInstant = 2  # export_to_perfetto_trace calls track.instant(...) for this kind
def decode_tag(tag, num_blocks, num_groups):
    """Split a packed profiler tag into its individual fields.

    Field layout (least-significant bit first):
        bits 0-1   -> event_type
        bits 2-11  -> event_idx
        bits 12-23 -> combined block/group index; split below as
                      block_idx = combined // num_groups and
                      group_idx = combined % num_groups
        bits 24-31 -> sm_id

    ``num_blocks`` is not read by this function.

    Returns:
        Tuple ``(block_idx, group_idx, event_idx, event_type, sm_id)``.
    """
    # Extract fields from the lowest bits upward.
    event_type = tag & 0x3
    event_idx = (tag >> 2) & 0x3FF
    combined = (tag >> 12) & 0xFFF
    sm_id = (tag >> 24) & 0xFF
    return (
        combined // num_groups,
        combined % num_groups,
        event_idx,
        event_type,
        sm_id,
    )
def export_to_perfetto_trace(
    profiler_buffer: torch.Tensor,
    event_names: List[str],
    file_name: str,
) -> None:
    """Convert a raw profiler buffer into a Perfetto trace file.

    Buffer layout, as decoded below (assumes a 1-D buffer):
      * entry 0 holds two int32 values: ``num_blocks`` followed by
        ``num_groups``;
      * every later entry is either 0 (skipped) or a record whose low 32 bits
        are a packed tag (see ``decode_tag``) and whose high 32 bits are a
        timestamp. This matches viewing the entry as two uint32 words
        (tag first) on a little-endian host.

    Each block becomes a Perfetto group named ``sm_{sm_id}_block_{block_idx}``
    (``sm_id`` taken from the first record seen for that block). Each
    (block, group) pair becomes a sub-group ``group_{group_idx}``. Each
    (block, group, event) triple gets its own track.

    Args:
        profiler_buffer: ``torch.uint64`` tensor with the header and records;
            it is copied to the host before decoding.
        event_names: Event names indexed by the tag's ``event_idx`` field.
        file_name: Output path handed to ``TraceGenerator``.

    Raises:
        AssertionError: If ``profiler_buffer`` is not ``torch.uint64``.
        ValueError: If ``profiler_buffer`` is empty (no header entry).
        IndexError: If a record's ``event_idx`` is out of range for
            ``event_names``.
    """
    assert profiler_buffer.dtype == torch.uint64, (
        f"profiler_buffer must be torch.uint64, got {profiler_buffer.dtype}"
    )
    profiler_buffer_host = profiler_buffer.cpu()
    if profiler_buffer_host.numel() == 0:
        raise ValueError("profiler_buffer is empty; expected a header entry")

    num_blocks, num_groups = (
        profiler_buffer_host[:1].view(dtype=torch.int32).tolist()
    )

    # Convert all records to Python ints in one pass. Indexing or slicing the
    # tensor element by element allocates new tensors per entry, which
    # dominates runtime on large buffers. Viewing as int64 is a same-width
    # reinterpretation; the masks below recover the unsigned 32-bit halves.
    records = profiler_buffer_host[1:].view(dtype=torch.int64).tolist()

    tgen = TraceGenerator(file_name)
    pid_map: Dict[int, Any] = {}
    tid_map: Dict[Tuple[int, int], Any] = {}
    track_map: Dict[Tuple[int, int, int], Any] = {}

    for raw in records:
        if raw == 0:
            continue  # unused slot
        tag = raw & 0xFFFFFFFF
        timestamp = (raw >> 32) & 0xFFFFFFFF
        block_idx, group_idx, event_idx, event_type, sm_id = decode_tag(
            tag, num_blocks, num_groups
        )

        # Lazily create groups/tracks the first time each key is seen.
        # dict.setdefault is avoided on purpose: it would call create_* even
        # when the key already exists.
        if block_idx not in pid_map:
            pid_map[block_idx] = tgen.create_group(
                f"sm_{sm_id}_block_{block_idx}"
            )
        pid = pid_map[block_idx]

        group_key = (block_idx, group_idx)
        if group_key not in tid_map:
            tid_map[group_key] = pid.create_group(f"group_{group_idx}")
        tid = tid_map[group_key]

        event = event_names[event_idx]
        track_key = (block_idx, group_idx, event_idx)
        if track_key not in track_map:
            track_map[track_key] = tid.create_track()
        track = track_map[track_key]

        if event_type == EventType.kBegin.value:
            track.open(timestamp, event)
        elif event_type == EventType.kEnd.value:
            track.close(timestamp)
        elif event_type == EventType.kInstant.value:
            track.instant(timestamp, event)
        # Any other event_type value (only 3 fits in two bits) is ignored.

    tgen.flush()