150 lines
4.5 KiB
Python
150 lines
4.5 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from typing import List, Optional
|
|
|
|
from .fusion_rules import FusionRule, get_default_fusion_rules
|
|
from .op import Op
|
|
from .types import TensorType
|
|
from .validators import CompileError, ValidityCheck, get_default_validity_checks
|
|
|
|
|
|
class Compiler:
    """Type-check, validate and fuse a linear pipeline of operators.

    A new instance installs the default validity checks and fusion rules
    (``get_default_validity_checks`` / ``get_default_fusion_rules``).
    Additional ones can be added with ``register_validity_check`` and
    ``register_fusion_rule`` before calling ``compile``.
    """

    def __init__(self) -> None:
        # Kept ordered by descending ``prio``; see ``register_fusion_rule``.
        self.fusion_rules: List[FusionRule] = []
        # Invoked in registration order by ``_run_validity_checks``.
        self.validity_checks: List[ValidityCheck] = []
        self._install_defaults()

    def register_fusion_rule(self, rule: FusionRule) -> None:
        """Add ``rule`` and re-sort the rule list by descending ``rule.prio``.

        ``list.sort`` is stable, so rules with equal ``prio`` keep their
        registration order.
        """
        self.fusion_rules.append(rule)
        self.fusion_rules.sort(key=lambda r: -r.prio)

    def register_validity_check(self, check: ValidityCheck) -> None:
        """Append ``check`` to the checks run on every ``compile`` call."""
        self.validity_checks.append(check)

    def compile(self, ops: List[Op]) -> List[Op]:
        """Type-check, validate and fuse ``ops``; return the resulting list.

        Parameters
        ----------
        ops : List[Op]
            Operators in execution order. The caller's list is not modified.

        Returns
        -------
        List[Op]
            A new list holding the operators after fusion.

        Raises
        ------
        CompileError
            If ``ops`` is empty or the operator input/output types do not
            chain. Anything raised by a validity check propagates unchanged.
        """
        if not ops:
            raise CompileError("Cannot compile empty operator list")

        # Shallow copy: fusion rewrites the list in place.
        compiled_ops = list(ops)

        self._type_check(compiled_ops)
        self._run_validity_checks(compiled_ops)
        return self._fuse_all(compiled_ops)

    def _type_check(self, ops: List[Op]) -> None:
        """Verify that each operator's ``IN`` equals its predecessor's ``OUT``.

        The first operator must accept ``TensorType.LOGITS`` or
        ``TensorType.PROBS``. ``ops`` must be non-empty.

        Raises
        ------
        CompileError
            On an unsupported first input type or a mismatch in the chain.
        """
        first_op = ops[0]
        current_type = first_op.IN

        if current_type not in (TensorType.LOGITS, TensorType.PROBS):
            raise CompileError(
                f"First operator ({first_op.__class__.__name__}) cannot accept standard pipeline inputs. "
                f"Expected LOGITS or PROBS, but operator accepts: {first_op.IN}"
            )

        # ``current_type`` always holds the type flowing into the next operator.
        for i, op in enumerate(ops):
            if current_type != op.IN:
                raise CompileError(
                    f"Type mismatch at operator {i} ({op.__class__.__name__}). "
                    f"Expected input type: {current_type}, but operator accepts: {op.IN}. "
                    f"Previous operator output: {current_type}"
                )
            current_type = op.OUT

    def _run_validity_checks(self, ops: List[Op]) -> None:
        """Call every registered validity check with ``ops``, in order."""
        for check in self.validity_checks:
            check(ops)

    def _fuse_all(self, ops: List[Op]) -> List[Op]:
        """Apply fusion rules to ``ops`` in place and return the same list.

        At each position the rules are tried in ``self.fusion_rules`` order.
        The first rule whose ``pattern`` matches the window starting there,
        and whose ``guard`` accepts that window, replaces the window with
        ``rule.build(window)``.
        """
        i = 0
        while i < len(ops):
            fusion_applied = False

            for rule in self.fusion_rules:
                span = len(rule.pattern)

                # An empty pattern matches the empty window at every position
                # and would insert an op without consuming any, so the loop
                # would never terminate. Such rules are ignored.
                if span == 0 or i + span > len(ops):
                    continue

                window = ops[i : i + span]

                if self._pattern_matches(window, rule.pattern) and rule.guard(window):
                    ops[i : i + span] = [rule.build(window)]

                    # Step back one slot so the new fused op can be combined
                    # with the operator before it.
                    i = max(i - 1, 0)
                    fusion_applied = True
                    break

            if not fusion_applied:
                i += 1

        return ops

    def _pattern_matches(self, window: List[Op], pattern: tuple) -> bool:
        """Return True if ``window[k]`` is an instance of ``pattern[k]`` for all k.

        Windows and patterns of different lengths never match.
        """
        if len(window) != len(pattern):
            return False

        return all(isinstance(op, kind) for op, kind in zip(window, pattern))

    def _install_defaults(self) -> None:
        """Register the default validity checks, then the default fusion rules."""
        for check in get_default_validity_checks():
            self.register_validity_check(check)

        for rule in get_default_fusion_rules():
            self.register_fusion_rule(rule)
|
|
|
|
|
|
def compile_pipeline(
    ops: List[Op],
    custom_fusion_rules: Optional[List[FusionRule]] = None,
    custom_validity_checks: Optional[List[ValidityCheck]] = None,
) -> List[Op]:
    """
    Compile ``ops`` with a fresh ``Compiler`` plus any caller-supplied extras.

    Parameters
    ----------
    ops : List[Op]
        Operators to compile, in execution order.
    custom_fusion_rules : Optional[List[FusionRule]]
        Extra fusion rules registered on the compiler before compiling.
    custom_validity_checks : Optional[List[ValidityCheck]]
        Extra validity checks registered on the compiler before compiling.

    Returns
    -------
    List[Op]
        The operator list returned by ``Compiler.compile``.
    """
    pipeline_compiler = Compiler()

    # ``None`` and an empty list both mean "nothing extra to register".
    for extra_rule in custom_fusion_rules or ():
        pipeline_compiler.register_fusion_rule(extra_rule)

    for extra_check in custom_validity_checks or ():
        pipeline_compiler.register_validity_check(extra_check)

    return pipeline_compiler.compile(ops)
|