124 lines
3.7 KiB
Python
124 lines
3.7 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from typing import Callable, List, NamedTuple, Tuple
|
|
|
|
from .op import Op
|
|
from .operators import (
|
|
FusedProbsMinPSampleOp,
|
|
FusedProbsTopKSampleOp,
|
|
FusedProbsTopKTopPSampleOp,
|
|
FusedProbsTopPSampleOp,
|
|
FusedTemperatureSoftmaxOp,
|
|
MinPOp,
|
|
ProbsSampleOp,
|
|
ProbsTopKOp,
|
|
SoftmaxOp,
|
|
TemperatureOp,
|
|
TopPOp,
|
|
)
|
|
|
|
|
|
class FusionRule(NamedTuple):
    """
    Describes one rewrite that collapses a run of operators into a fused one.

    Attributes:
        pattern: Tuple of operator types to match, in order
            (e.g., (TopK, Sampling)).
        guard: Predicate called with the matched window of operators; the
            window is a slice of the pipeline that exactly matches
            ``pattern``. Fusion is applied only when it returns True.
        build: Factory called with the same matched window; returns the
            fused operator that replaces it.
        prio: Priority for rule application (higher = tried earlier).
            Defaults to 0.
    """

    # Operator types that must appear consecutively, in this order.
    pattern: Tuple[type, ...]
    # Decides whether a matched window should actually be fused.
    guard: Callable[[List[Op]], bool]
    # Produces the fused replacement for a matched window.
    build: Callable[[List[Op]], Op]
    # Higher values are tried earlier.
    prio: int = 0
|
|
|
|
|
|
def joint_topk_topp_sampleprobs_guard(window: List[Op]) -> bool:
    """
    Only fuse when joint_topk_topp is True.

    Args:
        window: Matched operators; only the first one (the top-k operator)
            is inspected.

    Returns:
        True when the first operator's ``default_params`` mapping holds a
        truthy ``"joint_topk_topp"`` entry. False when the operator has no
        ``default_params`` attribute, the key is absent, or the stored value
        is falsy.
    """
    topk_op = window[0]
    # Operators without a default_params attribute are treated as having
    # no parameters at all.
    params = getattr(topk_op, "default_params", {})
    # Coerce so the declared ``-> bool`` contract holds even when the stored
    # value is some other truthy/falsy object (e.g. None or 1).
    return bool(params.get("joint_topk_topp", False))
|
|
|
|
|
|
def build_temperature_softmax(window: List[Op]) -> Op:
    """
    Build the fused temperature + softmax operator for a matched window.

    The ``enable_pdl`` setting is read from the second operator in the
    window (the softmax operator) via its ``default_params`` mapping and
    forwarded to the fused operator. It is ``None`` when the operator has
    no ``default_params`` attribute or the key is absent.
    """
    softmax_params = getattr(window[1], "default_params", {})
    return FusedTemperatureSoftmaxOp(
        enable_pdl=softmax_params.get("enable_pdl", None)
    )
|
|
|
|
|
|
def build_topk_sampling(window: List[Op]) -> Op:
    """
    Build the fused top-k + sampling operator for a matched window.

    The ``deterministic`` setting is taken from the second operator in the
    window (the sampling operator) via its ``default_params`` mapping. It
    falls back to ``True`` when the operator has no ``default_params``
    attribute or the key is absent.
    """
    sampler_params = getattr(window[1], "default_params", {})
    use_deterministic = sampler_params.get("deterministic", True)
    return FusedProbsTopKSampleOp(deterministic=use_deterministic)
|
|
|
|
|
|
def build_topp_sampling(window: List[Op]) -> Op:
    """
    Build the fused top-p + sampling operator for a matched window.

    The ``deterministic`` setting is taken from the second operator in the
    window (the sampling operator) via its ``default_params`` mapping. It
    falls back to ``True`` when the operator has no ``default_params``
    attribute or the key is absent.
    """
    sampler_params = getattr(window[1], "default_params", {})
    use_deterministic = sampler_params.get("deterministic", True)
    return FusedProbsTopPSampleOp(deterministic=use_deterministic)
|
|
|
|
|
|
def build_minp_sampling(window: List[Op]) -> Op:
    """
    Build the fused min-p + sampling operator for a matched window.

    The ``deterministic`` setting is taken from the second operator in the
    window (the sampling operator) via its ``default_params`` mapping. It
    falls back to ``True`` when the operator has no ``default_params``
    attribute or the key is absent.
    """
    sampler_params = getattr(window[1], "default_params", {})
    use_deterministic = sampler_params.get("deterministic", True)
    return FusedProbsMinPSampleOp(deterministic=use_deterministic)
|
|
|
|
|
|
def get_default_fusion_rules() -> List[FusionRule]:
    """
    Return a newly built list of the default fusion rules.

    The list contains, in order:

    * Temperature + Softmax (prio 100, always fused)
    * ProbsTopK + TopP + ProbsSample (prio 100, fused only when
      ``joint_topk_topp_sampleprobs_guard`` accepts the window)
    * ProbsTopK + ProbsSample (prio 10, always fused)
    * TopP + ProbsSample (prio 10, always fused)
    * MinP + ProbsSample (prio 10, always fused)

    The three-operator rule carries a higher priority than the two-operator
    sampling rules whose patterns overlap with it.
    """
    temperature_softmax = FusionRule(
        pattern=(TemperatureOp, SoftmaxOp),
        guard=lambda window: True,
        build=build_temperature_softmax,
        prio=100,
    )

    # NOTE(review): unlike the two-operator sampling builders, this rule
    # constructs the fused operator with no arguments, so the sampling
    # operator's "deterministic" setting is not forwarded -- confirm that
    # this is intended.
    topk_topp_sample = FusionRule(
        pattern=(ProbsTopKOp, TopPOp, ProbsSampleOp),
        guard=joint_topk_topp_sampleprobs_guard,
        build=lambda window: FusedProbsTopKTopPSampleOp(),
        prio=100,
    )

    topk_sample = FusionRule(
        pattern=(ProbsTopKOp, ProbsSampleOp),
        guard=lambda window: True,
        build=build_topk_sampling,
        prio=10,
    )

    topp_sample = FusionRule(
        pattern=(TopPOp, ProbsSampleOp),
        guard=lambda window: True,
        build=build_topp_sampling,
        prio=10,
    )

    minp_sample = FusionRule(
        pattern=(MinPOp, ProbsSampleOp),
        guard=lambda window: True,
        build=build_minp_sampling,
        prio=10,
    )

    return [
        temperature_softmax,
        topk_topp_sample,
        topk_sample,
        topp_sample,
        minp_sample,
    ]
|