# mypy: ignore-errors

import argparse
import ast
import json
import re
import sys
from pathlib import Path


# Matches a starred reference such as ``*graph_break_hints.USER_ERROR`` in the
# source text of a ``hints=[...]`` argument; group 1 is the constant's name.
_HINT_REF_RE = re.compile(r"\*graph_break_hints\.([A-Za-z_][A-Za-z0-9_]*)")


def get_source_segment(source, node):
    """Return the exact source text of ``node`` within ``source``."""
    return ast.get_source_segment(source, node)


def load_registry(path):
    """Load the registry JSON at ``path``; return ``{}`` when the file does not exist."""
    if path.exists():
        with path.open(encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_registry(reg, path):
    """Write ``reg`` to ``path`` as indented JSON, overwriting any existing file."""
    with path.open("w", encoding="utf-8") as f:
        json.dump(reg, f, indent=2)


def next_gb_id(reg):
    """
    Return the next free ``GBxxxx`` id: one past the largest numeric id in ``reg``.

    Keys not of the form ``GB<digits>`` are ignored; an empty registry yields
    ``GB0001``.
    """
    ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()]
    return f"GB{(max(ids, default=0) + 1):04d}"


def clean_string(s):
    """
    Normalize the source text of an argument into a plain string.

    Strips an f-string prefix, joins implicitly concatenated quoted pieces,
    removes surrounding quotes, collapses line breaks into spaces and decodes
    backslash escape sequences. Non-string inputs are returned unchanged.
    """
    if isinstance(s, str):
        # Convert f-string prefix to a regular quote (e.g., f"hello" -> "hello")
        s = re.sub(r'^f["\']', r'"', s)
        # Join a following f-string piece with a space (e.g., "a" f"b" -> "a b")
        s = re.sub(r'["\'] f["\']', " ", s)
        # Remove surrounding quotes, keeping only the content (e.g., "hello" -> hello)
        s = re.sub(r'^["\'](.*)["\']$', r"\1", s)
        # Collapse multi-line source into a single line
        s = " ".join(s.splitlines())
        # Decode escape sequences (\n, \", \uXXXX, ...). The unicode_escape codec
        # reads its input bytes as latin-1, so the text is encoded as latin-1 and
        # any character outside latin-1 is turned into a \uXXXX escape first;
        # encoding as UTF-8 instead would turn e.g. "é" into "Ã©".
        s = s.encode("latin-1", "backslashreplace").decode("unicode_escape")
        # Join adjacent quoted pieces with a space (e.g., a" "b -> a b)
        s = re.sub(r'" "', " ", s)
    return s


def _load_hint_constants():
    """Return the upper-case list constants of ``torch._dynamo.graph_break_hints``."""
    from torch._dynamo import graph_break_hints

    return {
        name: value
        for name, value in graph_break_hints.__dict__.items()
        if isinstance(value, list) and name.isupper()
    }


def expand_hints(hints, hint_constants=None):
    """
    Expand ``*graph_break_hints.NAME`` references into the hints they stand for.

    Args:
        hints: Iterable of strings (argument source text) that may contain
            references.
        hint_constants: Optional mapping ``NAME -> list of hint strings``.
            Defaults to the upper-case list constants of
            ``torch._dynamo.graph_break_hints``, imported only when at least one
            reference is present.

    Returns:
        The concatenated expansions of every reference, in order of appearance.
        References to unknown names contribute nothing.
    """
    refs = [name for hint in hints for name in _HINT_REF_RE.findall(hint)]
    if not refs:
        return []
    if hint_constants is None:
        hint_constants = _load_hint_constants()
    expanded_hints = []
    for name in refs:
        expanded_hints.extend(hint_constants.get(name, []))
    return expanded_hints


def extract_info_from_keyword(source, kw):
    """
    Extracts and returns the value of a keyword argument from an AST node.

    This function handles different types of AST nodes:
    - If the node is a constant, it returns the constant value.
    - If the node is an f-string, it reconstructs the string by keeping literal
      parts verbatim and rendering each formatted value as ``{<expression source>}``.
    - For other types, it cleans the source segment to remove formatting artifacts.
    """
    value_node = kw.value
    if isinstance(value_node, ast.Constant):
        return value_node.value
    if isinstance(value_node, ast.JoinedStr):
        evaluated_context = []
        for value in value_node.values:
            if isinstance(value, ast.FormattedValue):
                evaluated_context.append(f"{{{ast.unparse(value.value)}}}")
            elif isinstance(value, ast.Constant):
                evaluated_context.append(value.value)
        return "".join(evaluated_context)
    return clean_string(get_source_segment(source, value_node))


def _normalize_hints(hints):
    """
    Turn the cleaned source text of a ``hints=[...]`` argument into a list.

    Double-quoted literals are kept in order, followed by the expansion of any
    ``*graph_break_hints.NAME`` references.
    """
    expanded_hints = re.findall(r'"([^"]*)"', hints)
    if "*graph_break_hints." in hints:
        expanded_hints.extend(expand_hints([hints]))
    return expanded_hints


def find_unimplemented_v2_calls(path):
    """
    Collect the keyword arguments of every ``unimplemented_v2(...)`` call.

    Args:
        path: A Python file, or a directory searched recursively for ``*.py``
            files (visited in sorted order so results are deterministic).

    Returns:
        A list of dicts with keys ``gb_type``, ``context``, ``explanation`` and
        ``hints``. Calls without a ``gb_type`` keyword are skipped; files that
        fail to parse are reported on stdout and skipped.
    """
    path = Path(path)
    file_paths = sorted(path.glob("**/*.py")) if path.is_dir() else [path]

    results = []
    for file_path in file_paths:
        with open(file_path, encoding="utf-8") as f:
            source = f.read()
        try:
            tree = ast.parse(source)
        except SyntaxError:
            print(f"Syntax error in {file_path}")
            continue

        for node in ast.walk(tree):
            # Only direct calls by bare name are recognised: unimplemented_v2(...)
            if not (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == "unimplemented_v2"
            ):
                continue
            info = {
                "gb_type": None,
                "context": None,
                "explanation": None,
                "hints": [],
            }
            for kw in node.keywords:
                if kw.arg in info:
                    info[kw.arg] = extract_info_from_keyword(source, kw)
            if info["gb_type"] is None:
                continue
            if info["hints"]:
                info["hints"] = _normalize_hints(info["hints"])
            results.append(info)
    return results


def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None):
    """
    Add a new graph break type to the registry.

    Args:
        gb_type: The graph break type to add
        file_path: Path to the file containing the unimplemented_v2 call
        registry_path: Path to the registry JSON file
        additional_info: Optional note stored as the single element of the
            entry's "Additional_Info" list

    Returns:
        True on success; False (after printing an error) if ``gb_type`` is
        already the latest gb_type of some registry entry, or no matching call
        is found in ``file_path``.
    """
    registry_path = Path(registry_path)
    reg = load_registry(registry_path)

    # Only the newest version (index 0) of each entry is considered current.
    existing_gb_types = {entry[0]["Gb_type"] for entry in reg.values()}
    if gb_type in existing_gb_types:
        print(
            f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique."
        )
        return False

    calls = find_unimplemented_v2_calls(Path(file_path))
    matching_call = next((call for call in calls if call["gb_type"] == gb_type), None)
    if not matching_call:
        print(
            f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}"
        )
        return False

    gb_id = next_gb_id(reg)
    new_entry = {
        "Gb_type": gb_type,
        "Context": matching_call["context"],
        "Explanation": matching_call["explanation"],
        "Hints": matching_call["hints"] or [],
    }
    if additional_info:
        new_entry["Additional_Info"] = [additional_info]
    reg[gb_id] = [new_entry]

    save_registry(reg, registry_path)
    print(f"Added {gb_type} to registry with ID {gb_id}")
    return True


def cmd_update_gb_type(
    old_gb_type, file_path, registry_path, new_gb_type=None, additional_info=None
):
    """
    Update an existing graph break type in the registry by adding a new version
    to the version history list.

    Args:
        old_gb_type: The current graph break type to update
        file_path: Path to the file containing the updated unimplemented_v2 call
        registry_path: Path to the registry JSON file
        new_gb_type: Optional new gb_type name to replace the old one
        additional_info: Optional note appended to the "Additional_Info" list
            carried over from the previous version

    Returns:
        True on success; False (after printing an error) otherwise.
    """
    registry_path = Path(registry_path)
    reg = load_registry(registry_path)

    # Map each entry's current (newest) gb_type to its registry id.
    gb_id_map = {entry[0]["Gb_type"]: entry_id for entry_id, entry in reg.items()}
    gb_id = gb_id_map.get(old_gb_type)

    if gb_id is None:
        print(f"Error: gb_type '{old_gb_type}' not found in registry.")
        return False

    search_gb_type = new_gb_type if new_gb_type else old_gb_type
    calls = find_unimplemented_v2_calls(Path(file_path))
    matching_call = next(
        (call for call in calls if call["gb_type"] == search_gb_type), None
    )

    if not matching_call:
        print(
            f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}"
        )
        return False

    if (
        matching_call["gb_type"] != old_gb_type
        and matching_call["gb_type"] in gb_id_map
    ):
        print(
            f"Error: New gb_type '{matching_call['gb_type']}' already exists in registry. Please use a unique gb_type."
        )
        return False

    latest_entry = reg[gb_id][0]
    new_entry = {
        "Gb_type": matching_call["gb_type"],
        "Context": matching_call["context"],
        "Explanation": matching_call["explanation"],
        "Hints": matching_call["hints"] or [],
    }

    if additional_info:
        # "or []" also covers a missing or null previous value.
        previous_info = latest_entry.get("Additional_Info") or []
        new_entry["Additional_Info"] = previous_info + [additional_info]
    elif "Additional_Info" in latest_entry:
        new_entry["Additional_Info"] = latest_entry["Additional_Info"]

    # Newest version first; older versions are kept as history.
    reg[gb_id].insert(0, new_entry)

    save_registry(reg, registry_path)
    print(
        f"Updated {old_gb_type} to {matching_call['gb_type']} in registry with ID {gb_id}"
    )
    return True


def create_registry(dynamo_dir, registry_path):
    """
    Build a fresh registry from every unimplemented_v2 call under ``dynamo_dir``.

    Entries get sequential ids starting at ``GB0000`` in sorted gb_type order.
    When several calls share a gb_type, the last one found wins. The result
    overwrites ``registry_path``.
    """
    gb_types = {
        info["gb_type"]: info for info in find_unimplemented_v2_calls(dynamo_dir)
    }

    registry = {
        f"GB{index:04d}": [
            {
                "Gb_type": gb_type,
                "Context": info["context"],
                "Explanation": info["explanation"],
                "Hints": info["hints"] or [],
            }
        ]
        for index, (gb_type, info) in enumerate(sorted(gb_types.items()))
    }

    save_registry(registry, Path(registry_path))


def main():
    """Command-line entry point: ``create``, ``add`` or ``update`` the registry."""
    script_dir = Path(__file__).resolve().parent
    repo_root = script_dir.parent.parent
    registry_path = script_dir / "graph_break_registry.json"

    try:
        import torch._dynamo

        default_dynamo_dir = str(Path(torch._dynamo.__file__).parent)
    except ImportError:
        default_dynamo_dir = str(repo_root / "torch" / "_dynamo")

    parser = argparse.ArgumentParser(description="Manage graph break registry.")
    subparsers = parser.add_subparsers(dest="command", help="Command to execute")

    create_parser = subparsers.add_parser("create", help="Create registry from scratch")
    create_parser.add_argument(
        "--dynamo_dir",
        type=str,
        default=default_dynamo_dir,
        help="Directory to search for unimplemented_v2 calls.",
    )

    add_parser = subparsers.add_parser("add", help="Add a gb_type to registry")
    add_parser.add_argument("gb_type", help="The gb_type to add")
    add_parser.add_argument(
        "file_path", help="Path to the file containing the unimplemented_v2 call"
    )
    add_parser.add_argument(
        "--additional-info", help="Optional additional information to include"
    )

    update_parser = subparsers.add_parser(
        "update", help="Update an existing gb_type in registry"
    )
    update_parser.add_argument("gb_type", help="The gb_type to update")
    update_parser.add_argument(
        "file_path",
        help="Path to the file containing the updated unimplemented_v2 call",
    )
    update_parser.add_argument(
        "--new_gb_type", help="New gb_type name if it has changed", default=None
    )
    update_parser.add_argument(
        "--additional-info", help="Optional additional information to include"
    )

    # Global option: must be given before the sub-command on the command line.
    parser.add_argument(
        "--registry-path",
        type=str,
        default=str(registry_path),
        help="Path to save the registry JSON file",
    )

    args = parser.parse_args()

    if args.command == "create":
        create_registry(args.dynamo_dir, args.registry_path)
    elif args.command == "add":
        success = cmd_add_new_gb_type(
            args.gb_type, args.file_path, args.registry_path, args.additional_info
        )
        if not success:
            sys.exit(1)
    elif args.command == "update":
        success = cmd_update_gb_type(
            args.gb_type,
            args.file_path,
            args.registry_path,
            args.new_gb_type,
            args.additional_info,
        )
        if not success:
            sys.exit(1)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()