mysora/opensora/acceleration/shardformer/policy/t5_encoder.py

from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription


class T5EncoderPolicy(Policy):
    def config_sanity_check(self):
        # This policy only swaps in JIT-fused kernels; tensor parallelism and
        # flash attention are not supported here and must stay disabled.
        assert not self.shard_config.enable_tensor_parallelism
        assert not self.shard_config.enable_flash_attention

    def preprocess(self):
        return self.model

    def module_policy(self):
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # Use JIT-fused operators: replace the forward of the feed-forward and
        # self-attention sublayers, plus their dropout + residual-add, with
        # fused implementations.
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )

        return policy

    def postprocess(self):
        return self.model
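

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this policy might be handed to ColossalAI's
# ShardFormer to patch a HuggingFace T5 encoder with the JIT-fused forwards.
# The ShardConfig fields and the optimize() call follow ColossalAI's public
# shardformer API; exact signatures may differ across ColossalAI versions.
#
#   from colossalai.shardformer import ShardConfig, ShardFormer
#   from transformers import T5EncoderModel
#
#   text_encoder = T5EncoderModel.from_pretrained("t5-small")
#   shard_config = ShardConfig(
#       enable_tensor_parallelism=False,  # required by config_sanity_check
#       enable_flash_attention=False,     # required by config_sanity_check
#       enable_jit_fused=True,            # enables the fused replacements above
#   )
#   shard_former = ShardFormer(shard_config=shard_config)
#   sharded_encoder, _ = shard_former.optimize(text_encoder, policy=T5EncoderPolicy())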