from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription


class T5EncoderPolicy(Policy):
    def config_sanity_check(self):
        # This policy only wires up JIT fusion; tensor parallelism and
        # flash attention are not handled here.
        assert not self.shard_config.enable_tensor_parallelism
        assert not self.shard_config.enable_flash_attention

    def preprocess(self):
        return self.model

    def module_policy(self):
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # use jit operator
        if self.shard_config.enable_jit_fused:
            # Replace T5LayerFF's forward and dropout-add with JIT-fused versions.
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            # Do the same for the self-attention layer.
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )

        return policy

    def postprocess(self):
        return self.model
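

# A minimal usage sketch, not part of the original file: it assumes the
# `ShardConfig`/`ShardFormer` API exposed by `colossalai.shardformer` and that
# `ShardFormer.optimize(model, policy=...)` accepts a custom policy instance;
# verify the exact signature against your ColossalAI version. Depending on the
# configuration, the distributed environment may need to be initialized first
# (e.g. via `colossalai.launch`).
if __name__ == "__main__":
    from transformers import T5EncoderModel

    from colossalai.shardformer import ShardConfig, ShardFormer

    # Disable tensor parallelism and flash attention so config_sanity_check
    # above passes, and turn on the JIT-fused operators this policy targets.
    shard_config = ShardConfig(
        enable_tensor_parallelism=False,
        enable_flash_attention=False,
        enable_jit_fused=True,
    )

    model = T5EncoderModel.from_pretrained("t5-small")
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model, policy=T5EncoderPolicy())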