375 lines
12 KiB
Python
375 lines
12 KiB
Python
# mypy: ignore-errors
|
|
|
|
import argparse
|
|
import ast
|
|
import json
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def get_source_segment(source, node):
    """
    Return the slice of ``source`` that the AST ``node`` was parsed from.

    Thin wrapper around ``ast.get_source_segment`` so that the rest of this
    script has a single place through which source text is recovered.
    """
    segment = ast.get_source_segment(source, node)
    return segment
|
|
|
|
|
|
def load_registry(path):
    """
    Read the registry JSON stored at ``path``.

    Args:
        path: A ``pathlib.Path``-like object (must provide ``exists`` and ``open``).

    Returns:
        The decoded JSON document, or an empty dict when the file does not exist.
    """
    if not path.exists():
        return {}

    with path.open() as handle:
        registry = json.load(handle)
    return registry
|
|
|
|
|
|
def save_registry(reg, path):
    """
    Serialize ``reg`` as JSON (2-space indent) into the file at ``path``.

    Args:
        reg: The JSON-serializable registry object to write.
        path: A ``pathlib.Path``-like object; the file is opened in "w" mode,
            so any previous content is replaced.
    """
    with path.open("w") as sink:
        json.dump(obj=reg, fp=sink, indent=2)
|
|
|
|
|
|
def next_gb_id(reg):
    """
    Return the next free registry ID, formatted as ``GB`` plus a 4-digit number.

    Only keys of the form ``GB<decimal digits>`` take part in the computation;
    every other key is ignored. The result is one more than the largest number
    found, or ``GB0001`` when no such key exists.

    Args:
        reg: The registry mapping (only its keys are inspected).

    Returns:
        The new ID string, e.g. ``"GB0042"``.
    """
    numbers = [
        int(key[2:])
        for key in reg
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() refuses to parse.
        if key.startswith("GB") and key[2:].isdecimal()
    ]
    return f"GB{max(numbers, default=0) + 1:04d}"
|
|
|
|
|
|
def clean_string(s):
    """
    Normalize the source text of a string expression into plain message text.

    Strips f-string prefixes and surrounding quotes, joins physical lines with
    single spaces, and resolves backslash escape sequences. Non-string input is
    returned unchanged.

    Args:
        s: The source segment to clean (any non-``str`` value is passed through).

    Returns:
        The cleaned string, or ``s`` itself when it is not a string.
    """
    if not isinstance(s, str):
        return s

    # Convert a leading f-string prefix to a plain opening quote (f"hello" -> "hello")
    s = re.sub(r'^f["\']', r'"', s)
    # Replace a closing quote followed by an f-prefixed opening quote with a space
    s = re.sub(r'["\'] f["\']', " ", s)
    # Remove surrounding quotes, keeping only the content ("hello" -> hello)
    s = re.sub(r'^["\'](.*)["\']$', r"\1", s)
    # Join physical lines with a single space
    s = " ".join(s.splitlines())
    # Resolve backslash escapes. Encoding as Latin-1 with backslashreplace keeps
    # non-ASCII characters intact: the unicode_escape codec reads raw bytes as
    # Latin-1, so a plain UTF-8 encode() here would turn them into mojibake.
    s = s.encode("latin-1", "backslashreplace").decode("unicode_escape")
    # Replace the quote pair between adjacent string literals with a space
    s = re.sub(r'" "', " ", s)
    return s
|
|
|
|
|
|
def expand_hints(hints):
|
|
# Expands hint references to their actual values from graph_break_hints.
|
|
from torch._dynamo import graph_break_hints
|
|
|
|
hint_constants = {
|
|
name: value
|
|
for name, value in graph_break_hints.__dict__.items()
|
|
if isinstance(value, list) and name.isupper()
|
|
}
|
|
|
|
expanded_hints = []
|
|
for hint in hints:
|
|
for name, value in hint_constants.items():
|
|
if f"*graph_break_hints.{name}" in hint:
|
|
expanded_hints.extend(value)
|
|
break
|
|
return expanded_hints
|
|
|
|
|
|
def extract_info_from_keyword(source, kw):
    """
    Turn the value of a keyword-argument AST node into registry text.

    Three shapes of value are handled:
    - a constant: its Python value is returned as-is;
    - an f-string: literal parts are kept and every formatted value is
      rendered as ``{<expression source>}``, then everything is concatenated;
    - anything else: the raw source segment is passed through ``clean_string``.

    Args:
        source: Full source text of the file the node was parsed from.
        kw: The ``ast.keyword`` node whose ``value`` is inspected.
    """
    value_node = kw.value
    segment = get_source_segment(source, value_node)

    if isinstance(value_node, ast.Constant):
        return value_node.value

    if not isinstance(value_node, ast.JoinedStr):
        return clean_string(segment)

    pieces = []
    for part in value_node.values:
        if isinstance(part, ast.FormattedValue):
            # Keep the placeholder, showing the expression instead of its value.
            pieces.append("{" + ast.unparse(part.value) + "}")
        elif isinstance(part, ast.Constant):
            pieces.append(part.value)
    return "".join(pieces)
|
|
|
|
|
|
def find_unimplemented_v2_calls(path):
    """
    Collect the keyword information of every ``unimplemented_v2(...)`` call.

    ``path`` may be a single file or a directory; a directory is searched
    recursively for ``*.py`` files. Only calls made through the bare name
    ``unimplemented_v2`` are matched, and calls without a ``gb_type`` keyword
    are skipped. Files that fail to parse are reported on stdout and ignored.

    Args:
        path: File or directory to scan (``str`` or ``Path``).

    Returns:
        A list of dicts with the keys ``gb_type``, ``context``, ``explanation``
        and ``hints``; ``hints`` is a list of hint strings.
    """
    path = Path(path)
    file_paths = path.glob("**/*.py") if path.is_dir() else [path]

    results = []
    for file_path in file_paths:
        # Python source is UTF-8 by default, so do not rely on the locale encoding.
        with open(file_path, encoding="utf-8") as f:
            source = f.read()

        try:
            tree = ast.parse(source)
        except SyntaxError:
            print(f"Syntax error in {file_path}")
            continue

        for node in ast.walk(tree):
            is_target_call = (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == "unimplemented_v2"
            )
            if not is_target_call:
                continue

            info = {
                "gb_type": None,
                "context": None,
                "explanation": None,
                "hints": [],
            }
            for kw in node.keywords:
                if kw.arg in info:
                    info[kw.arg] = extract_info_from_keyword(source, kw)

            if info["gb_type"] is None:
                continue

            raw_hints = info["hints"]
            if raw_hints:
                # Literal hints are the double-quoted pieces of the source text.
                expanded_hints = re.findall(r'"([^"]*)"', raw_hints)
                # References to shared hint lists are expanded to their contents.
                if "*graph_break_hints." in raw_hints:
                    expanded_hints.extend(expand_hints([raw_hints]))
                info["hints"] = expanded_hints

            results.append(info)

    return results
|
|
|
|
|
|
def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None):
    """
    Register a graph break type that is not yet present in the registry.

    The matching ``unimplemented_v2`` call is looked up in ``file_path`` and its
    details are stored under a freshly allocated ID.

    Args:
        gb_type: The graph break type to add
        file_path: Path to the file containing the unimplemented_v2 call
        registry_path: Path to the registry JSON file
        additional_info: Optional extra text stored with the new entry

    Returns:
        True when the entry was written, False when the gb_type already exists
        or no matching call was found (an error is printed in both cases).
    """
    registry_path = Path(registry_path)
    reg = load_registry(registry_path)

    # The first element of each history list is the current version.
    taken = frozenset(history[0]["Gb_type"] for history in reg.values())
    if gb_type in taken:
        print(
            f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique."
        )
        return False

    found = None
    for candidate in find_unimplemented_v2_calls(Path(file_path)):
        if candidate["gb_type"] == gb_type:
            found = candidate
            break

    if not found:
        print(
            f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}"
        )
        return False

    entry = {
        "Gb_type": gb_type,
        "Context": found["context"],
        "Explanation": found["explanation"],
        "Hints": found["hints"] or [],
    }
    if additional_info:
        entry["Additional_Info"] = [additional_info]

    gb_id = next_gb_id(reg)
    reg[gb_id] = [entry]

    save_registry(reg, registry_path)
    print(f"Added {gb_type} to registry with ID {gb_id}")
    return True
|
|
|
|
|
|
def cmd_update_gb_type(
    old_gb_type, file_path, registry_path, new_gb_type=None, additional_info=None
):
    """
    Record a new version of an existing graph break type.

    The new version is placed at the front of the entry's version history, so
    index 0 always holds the current state.

    Args:
        old_gb_type: The current graph break type to update
        file_path: Path to the file containing the updated unimplemented_v2 call
        registry_path: Path to the registry JSON file
        new_gb_type: Optional new gb_type name to replace the old one
        additional_info: Optional extra text appended to the entry's
            ``Additional_Info`` list; when omitted, the existing list (if any)
            is carried over unchanged

    Returns:
        True when the registry was updated, False on any of the reported errors.
    """
    registry_path = Path(registry_path)
    reg = load_registry(registry_path)

    ids_by_type = {history[0]["Gb_type"]: key for key, history in reg.items()}
    gb_id = ids_by_type.get(old_gb_type)
    if gb_id is None:
        print(f"Error: gb_type '{old_gb_type}' not found in registry.")
        return False

    # After a rename, the call in the file carries the new name.
    search_gb_type = new_gb_type or old_gb_type

    matching_call = None
    for candidate in find_unimplemented_v2_calls(Path(file_path)):
        if candidate["gb_type"] == search_gb_type:
            matching_call = candidate
            break

    if not matching_call:
        print(
            f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}"
        )
        return False

    found_type = matching_call["gb_type"]
    if found_type != old_gb_type and found_type in ids_by_type:
        print(
            f"Error: New gb_type '{found_type}' already exists in registry. Please use a unique gb_type."
        )
        return False

    current = reg[gb_id][0]
    new_entry = {
        "Gb_type": found_type,
        "Context": matching_call["context"],
        "Explanation": matching_call["explanation"],
        "Hints": matching_call["hints"] or [],
    }

    if additional_info:
        previous_info = current.get("Additional_Info", [])
        new_entry["Additional_Info"] = (previous_info or []) + [additional_info]
    elif "Additional_Info" in current:
        new_entry["Additional_Info"] = current["Additional_Info"]

    reg[gb_id].insert(0, new_entry)

    save_registry(reg, registry_path)
    print(f"Updated {old_gb_type} to {found_type} in registry with ID {gb_id}")
    return True
|
|
|
|
|
|
def create_registry(dynamo_dir, registry_path):
    """
    Build a registry from scratch and write it to ``registry_path``.

    All ``unimplemented_v2`` calls found under ``dynamo_dir`` are collected and
    de-duplicated by gb_type (the last call seen for a type wins). IDs are then
    assigned in sorted gb_type order, starting at ``GB0000``.

    Args:
        dynamo_dir: File or directory to scan for unimplemented_v2 calls.
        registry_path: Destination of the registry JSON file (overwritten).
    """
    latest_by_type = {
        call["gb_type"]: call for call in find_unimplemented_v2_calls(dynamo_dir)
    }

    registry = {}
    for index, gb_type in enumerate(sorted(latest_by_type)):
        call = latest_by_type[gb_type]
        registry[f"GB{index:04d}"] = [
            {
                "Gb_type": gb_type,
                "Context": call["context"],
                "Explanation": call["explanation"],
                "Hints": call["hints"] or [],
            }
        ]

    with open(registry_path, "w") as out:
        json.dump(registry, out, indent=2)
|
|
|
|
|
|
def main():
    """
    Command-line entry point: parse the arguments and run the chosen command.

    Sub-commands are ``create``, ``add`` and ``update``; without one, the help
    text is printed. A failed ``add`` or ``update`` exits with status 1.
    """
    here = Path(__file__).resolve().parent
    default_registry = here / "graph_break_registry.json"

    # Prefer the installed torch._dynamo; otherwise fall back to the directory
    # located relative to this script.
    try:
        import torch._dynamo

        default_dynamo_dir = str(Path(torch._dynamo.__file__).parent)
    except ImportError:
        default_dynamo_dir = str(here.parent.parent / "torch" / "_dynamo")

    parser = argparse.ArgumentParser(description="Manage graph break registry.")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    # create
    create_parser = subparsers.add_parser("create", help="Create registry from scratch")
    create_parser.add_argument(
        "--dynamo_dir",
        type=str,
        default=default_dynamo_dir,
        help="Directory to search for unimplemented_v2 calls.",
    )

    # add
    add_parser = subparsers.add_parser("add", help="Add a gb_type to registry")
    add_parser.add_argument("gb_type", help="The gb_type to add")
    add_parser.add_argument(
        "file_path", help="Path to the file containing the unimplemented_v2 call"
    )
    add_parser.add_argument(
        "--additional-info", help="Optional additional information to include"
    )

    # update
    update_parser = subparsers.add_parser(
        "update", help="Update an existing gb_type in registry"
    )
    update_parser.add_argument("gb_type", help="The gb_type to update")
    update_parser.add_argument(
        "file_path",
        help="Path to the file containing the updated unimplemented_v2 call",
    )
    update_parser.add_argument(
        "--new_gb_type", help="New gb_type name if it has changed", default=None
    )
    update_parser.add_argument(
        "--additional-info", help="Optional additional information to include"
    )

    # Option shared by every command; it belongs to the top-level parser.
    parser.add_argument(
        "--registry-path",
        type=str,
        default=str(default_registry),
        help="Path to save the registry JSON file",
    )

    args = parser.parse_args()
    command = args.command

    if command == "create":
        create_registry(args.dynamo_dir, args.registry_path)
        return

    if command == "add":
        succeeded = cmd_add_new_gb_type(
            args.gb_type, args.file_path, args.registry_path, args.additional_info
        )
    elif command == "update":
        succeeded = cmd_update_gb_type(
            args.gb_type,
            args.file_path,
            args.registry_path,
            args.new_gb_type,
            args.additional_info,
        )
    else:
        parser.print_help()
        return

    if not succeeded:
        sys.exit(1)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if "__main__" == __name__:
    main()
|