# Owner(s): ["module: higher order operators"]
import importlib
import pkgutil

import torch
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.hop_db import (
    FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
    hop_db,
)


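# Import every submodule of torch._higher_order_ops up front: importing a
# HOP's module triggers its registration with torch._ops._higher_order_ops,
# so the checks below only see the complete set of HOPs once every module has
# been imported.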
def do_imports():
    for mod in pkgutil.walk_packages(
        torch._higher_order_ops.__path__, "torch._higher_order_ops."
    ):
        modname = mod.name
        importlib.import_module(modname)


do_imports()


@skipIfTorchDynamo("not applicable")
class TestHOPInfra(TestCase):
    def test_all_hops_have_opinfo(self):
        """All HOPs should have an OpInfo in torch/testing/_internal/hop_db.py"""
        from torch._ops import _higher_order_ops

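        # hop_db is a list of OpInfo entries; each entry's .name is compared
        # against the names under which HOPs are registered.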
        hops_that_have_op_info = {k.name for k in hop_db}
        all_hops = _higher_order_ops.keys()

        missing_ops = set()

        for op in all_hops:
            if (
                op not in hops_that_have_op_info
                and op not in FIXME_hop_that_doesnt_have_opinfo_test_allowlist
            ):
                missing_ops.add(op)

        self.assertTrue(
            len(missing_ops) == 0,
            f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py",
        )

    def test_all_hops_are_imported(self):
        """All HOPs should be listed in torch._higher_order_ops.__all__

        Some constraints (see test_testing.py::TestImports)
        - Sympy must be lazily imported
        - Dynamo must be lazily imported
        """
        imported_hops = torch._higher_order_ops.__all__
        registered_hops = torch._ops._higher_order_ops.keys()

        # Please don't add anything here.
        # We want to ensure that all HOPs are imported at "import torch" time.
        # It is bad if someone tries to access torch.ops.higher_order.cond
        # and it doesn't exist (this may happen if your HOP isn't imported at
        # "import torch" time).
        FIXME_ALLOWLIST = {
            "autograd_function_apply",
            "run_with_rng_state",
            "graphsafe_run_with_rng_state",
            "map_impl",
            "_export_tracepoint",
            "run_and_save_rng_state",
            "map",
            "custom_function_call",
            "trace_wrapped",
            "triton_kernel_wrapper_functional",
            "triton_kernel_wrapper_mutation",
            "wrap",  # Really weird failure -- importing this causes Dynamo to choke on checkpoint
        }
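        # Anything registered with torch._ops but not exported via __all__
        # (and not already excused above) is reported as a failure.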
        not_imported_hops = registered_hops - imported_hops
        not_imported_hops = not_imported_hops - FIXME_ALLOWLIST
        self.assertEqual(
            not_imported_hops,
            set(),
            msg="All HOPs must be listed under torch/_higher_order_ops/__init__.py's __all__.",
        )

    def test_imports_from_all_work(self):
        """All APIs listed in torch._higher_order_ops.__all__ must be importable"""
        stuff = torch._higher_order_ops.__all__
        for attr in stuff:
            # getattr raises if an entry in __all__ does not actually resolve
            # to an attribute on the module, which fails the test.
            getattr(torch._higher_order_ops, attr)


if __name__ == "__main__":
    run_tests()