from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription


class T5EncoderPolicy(Policy):
    """Shardformer policy for a T5 encoder that only swaps in replacement forward methods.

    This policy does not support tensor parallelism or flash attention; both are
    rejected in ``config_sanity_check``. When ``shard_config.enable_jit_fused`` is
    set, ``module_policy`` replaces ``forward`` and ``dropout_add`` on the HF
    ``T5LayerFF`` and ``T5LayerSelfAttention`` classes; otherwise the returned
    policy is empty. ``preprocess`` and ``postprocess`` return the model unchanged.
    """

    def config_sanity_check(self):
        """Reject shard-config options that this policy does not implement.

        Raises:
            AssertionError: if tensor parallelism or flash attention is enabled.
        """
        # NOTE: these are ``assert`` statements, so they are skipped when Python runs with -O.
        assert (
            not self.shard_config.enable_tensor_parallelism
        ), "T5EncoderPolicy does not support tensor parallelism (enable_tensor_parallelism must be False)."
        assert (
            not self.shard_config.enable_flash_attention
        ), "T5EncoderPolicy does not support flash attention (enable_flash_attention must be False)."

    def preprocess(self):
        """Return ``self.model`` unchanged; this policy needs no preprocessing."""
        return self.model

    def module_policy(self) -> dict:
        """Build the method-replacement policy for the T5 encoder layers.

        Returns:
            dict: the policy populated via ``append_or_create_method_replacement``,
            with entries targeting ``T5LayerFF`` and ``T5LayerSelfAttention``.
            It is empty when ``shard_config.enable_jit_fused`` is False.
        """
        # Function-scope import: ``transformers`` is only imported when this method is called.
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention

        policy = {}

        # Use the JIT-fused operators only when explicitly enabled in the shard config.
        if self.shard_config.enable_jit_fused:
            # Feed-forward sub-layer: JIT-fused forward plus JIT-fused ``dropout_add``.
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            # Self-attention sub-layer: replacement forward plus JIT-fused ``dropout_add``.
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )
        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged; this policy needs no postprocessing."""
        return self.model