102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
from collections import namedtuple
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import torch
|
|
from tg4perfetto import TraceGenerator
|
|
|
|
|
|
class EventType(Enum):
    """Kind of profiler event encoded in the low two bits of a profiler tag.

    ``kBegin`` and ``kEnd`` bracket a slice on a trace track, while
    ``kInstant`` marks a single point in time.
    """

    # These integers are compared against the 2-bit event-type field decoded
    # from each tag (via ``.value``), so they must stay fixed.
    kBegin = 0
    kEnd = 1
    kInstant = 2
|
|
|
|
|
|
def decode_tag(tag, num_blocks, num_groups):
    """Unpack a 32-bit profiler tag into its individual fields.

    Bit layout of ``tag``, least-significant bits first::

        bits  0-1 : event_type
        bits  2-11: event_idx
        bits 12-23: block_group_idx  (== block_idx * num_groups + group_idx)
        bits 24-31: sm_id

    Args:
        tag: Packed tag value (an integer).
        num_blocks: Not used by the decoding itself; kept for the call signature.
        num_groups: Groups per block, used to split ``block_group_idx``.

    Returns:
        ``(block_idx, group_idx, event_idx, event_type, sm_id)``.

    Raises:
        ZeroDivisionError: If ``num_groups`` is 0.
    """
    # Walk the fields from the least-significant end upward.
    event_type = tag & 0x3
    event_idx = (tag >> 2) & 0x3FF
    block_group_idx = (tag >> 12) & 0xFFF
    sm_id = (tag >> 24) & 0xFF

    # block_group_idx flattens (block, group); undo that flattening.
    block_idx = block_group_idx // num_groups
    group_idx = block_group_idx % num_groups

    return block_idx, group_idx, event_idx, event_type, sm_id
|
|
|
|
|
|
def export_to_perfetto_trace(
    profiler_buffer: torch.Tensor,
    event_names: List[str],
    file_name: str,
) -> None:
    """Convert a raw profiler buffer into a Perfetto trace.

    Buffer layout (a 1-D ``torch.uint64`` tensor):

    * entry 0 is a header holding two int32 values, ``(num_blocks, num_groups)``;
    * every later entry holds two uint32 values, ``(tag, timestamp)``, where
      ``tag`` is unpacked by ``decode_tag``;
    * entries that are entirely zero are skipped (presumably unused slots --
      TODO confirm with the producer).

    Trace structure: each block gets a group named
    ``sm_<sm_id>_block_<block_idx>`` (``sm_id`` taken from the first event seen
    for that block), each (block, group) pair a child group
    ``group_<group_idx>``, and each (block, group, event_idx) triple its own
    track. Begin/end events open/close a slice on that track, instant events
    add an instant marker, and any other event type is ignored.

    Args:
        profiler_buffer: Raw profiler buffer; it is copied to the host first.
        event_names: Event names indexed by the ``event_idx`` field of a tag.
        file_name: Passed to ``TraceGenerator`` (presumably the output path).

    Raises:
        TypeError: If ``profiler_buffer`` is not ``torch.uint64``.
        ValueError: If ``profiler_buffer`` is empty (no header entry).
        IndexError: If a tag's ``event_idx`` is out of range for ``event_names``.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if profiler_buffer.dtype != torch.uint64:
        raise TypeError(
            f"profiler_buffer must have dtype torch.uint64, got {profiler_buffer.dtype}"
        )
    profiler_buffer_host = profiler_buffer.cpu()
    if profiler_buffer_host.numel() == 0:
        raise ValueError("profiler_buffer is empty; expected at least a header entry")

    # Reinterpret each 64-bit entry as its two consecutive 32-bit words (same
    # memory order as a per-entry ``.view(torch.uint32)``), and convert the whole
    # buffer to Python ints in one call instead of doing several tensor ops per
    # entry inside the loop. int32 is used because torch's unsigned 32/64-bit
    # dtypes have only limited operator support; unsigned values are recovered
    # by masking below.
    words = profiler_buffer_host.view(dtype=torch.int32).reshape(-1, 2).tolist()

    # Header: two int32 values, exactly as a per-entry int32 view would give.
    num_blocks, num_groups = words[0]

    tgen = TraceGenerator(file_name)

    pid_map: Dict[int, Any] = {}
    tid_map: Dict[Tuple[int, int], Any] = {}
    track_map: Dict[Tuple[int, int, int], Any] = {}

    for raw_tag, raw_timestamp in words[1:]:
        # A 64-bit entry is zero iff both of its 32-bit halves are zero.
        if raw_tag == 0 and raw_timestamp == 0:
            continue
        # Mask back to the unsigned 32-bit values a uint32 view would yield.
        tag = raw_tag & 0xFFFFFFFF
        timestamp = raw_timestamp & 0xFFFFFFFF
        block_idx, group_idx, event_idx, event_type, sm_id = decode_tag(
            tag, num_blocks, num_groups
        )

        # Lazily build the hierarchy: block group -> group sub-group -> track.
        if block_idx not in pid_map:
            pid_map[block_idx] = tgen.create_group(f"sm_{sm_id}_block_{block_idx}")
        pid = pid_map[block_idx]

        if (block_idx, group_idx) not in tid_map:
            tid_map[(block_idx, group_idx)] = pid.create_group(f"group_{group_idx}")
        tid = tid_map[(block_idx, group_idx)]

        event = event_names[event_idx]

        track_key = (block_idx, group_idx, event_idx)
        if track_key in track_map:
            track = track_map[track_key]
        else:
            track = tid.create_track()
            track_map[track_key] = track

        if event_type == EventType.kBegin.value:
            track.open(timestamp, event)
        elif event_type == EventType.kEnd.value:
            track.close(timestamp)
        elif event_type == EventType.kInstant.value:
            track.instant(timestamp, event)

    tgen.flush()
|