from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.attn import RingComm, _rescale_out_lse
from colossalai.shardformer.layer.utils import SeqParallelUtils
from diffusers.models.attention_processor import Attention

from opensora.models.vae.tensor_parallel import Conv3dTPRow
from opensora.models.vae.utils import get_conv3d_n_chunks

from .unet_causal_3d_blocks import UpsampleCausal3D

try:
    from xformers.ops.fmha import (
        Context,
        Inputs,
        _memory_efficient_attention_backward,
        _memory_efficient_attention_forward_requires_grad,
    )

    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False

# Chunking parameters for context-chunked attention: chunk boundaries are kept
# aligned to SEQ_ALIGN and chunking kicks in when the sequence exceeds SEQ_LIMIT
# (exact policy delegated to get_conv3d_n_chunks).
SEQ_ALIGN = 32
SEQ_LIMIT = 16 * 1024


def align_atten_bias(attn_bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Pad the last (key) dimension of an attention bias to a multiple of 8.

    xformers' memory-efficient attention kernels require the bias storage to be
    aligned; the returned tensor is a view sliced back to the original key
    length, so the *values* seen by the kernel are unchanged — only the
    underlying storage is widened.

    Args:
        attn_bias: bias of shape [B, N, S_q, S_kv], or None.

    Returns:
        The (possibly storage-padded) bias, or None if no bias was given.
    """
    if attn_bias is None:
        # Callers pass Optional biases (e.g. ring attention with no mask).
        return None
    B, N, S_q, S_kv = attn_bias.shape
    align_size = 8
    if S_kv % align_size != 0:
        padded_S_kv = (S_kv // align_size + 1) * align_size
        # Allocate aligned storage, then copy the real values into the
        # leading slice; the slice view is what gets returned.
        attn_bias = torch.empty(
            (B, N, S_q, padded_S_kv), dtype=attn_bias.dtype, device=attn_bias.device
        )[:, :, :, :S_kv].copy_(attn_bias)
    return attn_bias


def _attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """Single-block memory-efficient attention forward (xformers).

    Args:
        q, k, v: [B, S, N, D] query/key/value blocks.
        attn_bias: optional additive bias [B, N, S_q, S_kv].
        scale: softmax scale; None lets xformers use its default.

    Returns:
        Tuple of (out [B, S_q, N, D], lse [B, N, S_q], rng_state).
    """
    attn_bias = align_atten_bias(attn_bias)
    inp = Inputs(q, k, v, attn_bias, p=0, scale=scale, is_partial=False)
    out, ctx = _memory_efficient_attention_forward_requires_grad(inp, None)
    # xformers may return an lse padded to the aligned length; trim it back to
    # the true query length so downstream shape checks hold.
    S_q = q.shape[1] if attn_bias is None else attn_bias.shape[-2]
    if ctx.lse.shape[-1] != S_q:
        ctx.lse = ctx.lse[:, :, :S_q]
    return out, ctx.lse, ctx.rng_state


def _attn_bwd(
    grad: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    rng_state: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """Single-block memory-efficient attention backward (xformers).

    Args mirror `_attn_fwd` plus the saved forward outputs (`out`, `lse`,
    `rng_state`).  Dropout is disabled (p=0), so `rng_state` is carried only
    for API completeness.

    Returns:
        Tuple of (dq, dk, dv) with the same shapes as (q, k, v).
    """
    attn_bias = align_atten_bias(attn_bias)
    inp = Inputs(q, k, v, attn_bias, p=0, scale=scale, output_dtype=q.dtype, is_partial=False)
    ctx = Context(lse, out, rng_state=rng_state)
    grads = _memory_efficient_attention_backward(ctx, inp, grad, None)
    return grads.dq, grads.dk, grads.dv


class MemEfficientRingAttention(torch.autograd.Function):
    """Ring attention built on xformers' memory-efficient attention kernels.

    KV blocks circulate around the sequence-parallel group while each rank keeps
    its local Q shard; partial outputs are merged with log-sum-exp rescaling.
    Communication of the next KV block is overlapped with computation using a
    dedicated CUDA stream (`SP_STREAM`) and an event (`ATTN_DONE`).
    """

    # Lazily created in `attention()`; shared across all invocations.
    ATTN_DONE: torch.cuda.Event = None
    SP_STREAM: torch.cuda.Stream = None

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        softmax_scale: Optional[float] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ring attention forward

        Args:
            ctx (_type_): self
            q (torch.Tensor): shape [B, S/P, N, D]
            k (torch.Tensor): shape [B, S/P, N, D]
            v (torch.Tensor): shape [B, S/P, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            sp_stream (torch.cuda.Stream): sequence parallel stream
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            attn_mask (Optional[torch.Tensor], optional): attention mask shape [B, N, S/P, S]. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output and log sum exp.
                Output's shape should be [B, S/P, N, D]. LSE's shape should be [B, N, S/P].
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        # Two communicators so consecutive ring steps can ping-pong buffers.
        kv_comms: List[RingComm] = [RingComm(sp_group) for _ in range(2)]

        block_attn_masks = [None] * sp_size
        if attn_mask is not None:
            # if attn_mask is splitted, uncomment the following line
            # attn_mask = attn_mask.chunk(sp_size, dim=2)[sp_rank]
            # One [B, N, S/P, S/P] mask block per source rank.
            block_attn_masks = attn_mask.chunk(sp_size, dim=-1)

        # [B, S, N, D]
        q, k, v = [x.contiguous() for x in [q, k, v]]
        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))

        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        def _kv_comm(i):
            # Avoid overwriting attn input when it shares mem with buffer
            if not MemEfficientRingAttention.ATTN_DONE.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < sp_size - 1:
                kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

        block_idx = sp_rank
        for i in range(sp_size):
            with torch.cuda.stream(sp_streams[i % 2]):
                # Wait for current kv from prev rank
                # NOTE: waiting outside the current stream will NOT correctly synchronize.
                if i == 0:
                    _kv_comm(i)
                else:
                    kv_comms[(i + 1) % 2].wait()

                kv_block = kv_buffers[i % 2]
                q_block = q
                block_out[i % 2], block_softmax_lse[i % 2], rng_states[i] = _attn_fwd(
                    q_block, kv_block[0], kv_block[1], attn_bias=block_attn_masks[block_idx], scale=softmax_scale
                )
                MemEfficientRingAttention.ATTN_DONE.record()

                # Pipeline the next KV comm with output correction instead of the next flash attn
                # to minimize idle time when comm takes longer than attn.
                _kv_comm(i + 1)

                block_softmax_lse[i % 2] = (
                    block_softmax_lse[i % 2].transpose(1, 2).unsqueeze(-1).contiguous().float()
                )  # [B, N, S] -> [B, S, N, 1]
                assert (
                    block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                ), f"{block_out[i % 2].shape} != {block_softmax_lse[i % 2].shape}"

                # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
                # In reality this always finishes before next flash attn; no need for extra sync.
                if i == 0:
                    out = block_out[0]
                    softmax_lse = block_softmax_lse[0]
                else:
                    out, softmax_lse = _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2])

                block_idx = (block_idx - 1) % sp_size

        torch.cuda.current_stream().wait_stream(sp_stream)
        out = out.to(q.dtype)
        softmax_lse = softmax_lse.squeeze(-1).transpose(1, 2).contiguous()

        ctx.softmax_scale = softmax_scale
        ctx.block_attn_masks = block_attn_masks
        ctx.sp_group = sp_group
        ctx.save_for_backward(q, k, v, out, softmax_lse, *rng_states)  # lse [B, N, S]
        return out, softmax_lse

    @staticmethod
    def backward(ctx, grad_output, grad_softmax_lse):
        """Ring attention backward: reverse the KV ring while accumulating dKV.

        dKV partials travel around the ring so each rank ends up with the full
        gradient for its own KV shard; dQ accumulates locally in fp32.
        `grad_softmax_lse` is ignored — the lse output is not differentiated.
        """
        # q, k, v, out: [B, S, N, D], softmax_lse: [B, N, S]
        q, k, v, out, softmax_lse, *rng_states = ctx.saved_tensors
        sp_group = ctx.sp_group
        sp_size = dist.get_world_size(sp_group)
        kv_comm = RingComm(sp_group)
        dkv_comm = RingComm(sp_group)

        grad_output = grad_output.contiguous()
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))
        dq = None
        # fp32 accumulators for the travelling dKV partial sums.
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float) for kv in kv_buffers]
        del k, v

        block_idx = dist.get_rank(sp_group)
        for i in range(sp_size):
            if i > 0:
                kv_comm.wait()
            if i < sp_size - 1:
                kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

            k_block, v_block = kv_buffers[i % 2]
            dq_block, dk_block, dv_block = _context_chunk_attn_bwd(
                grad_output,
                q,
                k_block,
                v_block,
                out,
                softmax_lse,
                rng_states[i],
                attn_bias=ctx.block_attn_masks[block_idx],
                scale=ctx.softmax_scale,
            )

            if i == 0:
                dq = dq_block.float()
                dkv_buffers[i % 2][0] = dk_block.float()
                dkv_buffers[i % 2][1] = dv_block.float()
            else:
                dq += dq_block
                # Accumulate into the dKV partial received from the prev rank.
                dkv_comm.wait()
                dkv_buffers[i % 2][0] += dk_block
                dkv_buffers[i % 2][1] += dv_block
            dkv_comm.send_recv(dkv_buffers[i % 2], dkv_buffers[(i + 1) % 2])
            block_idx = (block_idx - 1) % sp_size

        dkv_comm.wait()
        dkv = dkv_buffers[sp_size % 2]
        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv)]
        torch.cuda.empty_cache()
        # Exactly one gradient per forward input:
        # (q, k, v, sp_group, sp_stream, softmax_scale, attn_mask).
        return dq, dk, dv, None, None, None, None

    @staticmethod
    def attention(
        q,
        k,
        v,
        sp_group,
        softmax_scale: Optional[float] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_softmax: bool = False,
    ):
        """Ring attention

        Args:
            q (torch.Tensor): shape [B, S, N, D]
            k (torch.Tensor): shape [B, S, N, D]
            v (torch.Tensor): shape [B, S, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            attn_mask (Optional[torch.Tensor], optional): attention mask. Defaults to None.
            return_softmax (bool, optional): return softmax or not. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output and log sum exp.
                Output's shape should be [B, S, N, D]. LSE's shape should be [B, N, S].
        """
        # Lazily create the shared event/stream used to overlap comm and attn.
        if MemEfficientRingAttention.ATTN_DONE is None:
            MemEfficientRingAttention.ATTN_DONE = torch.cuda.Event()
        if MemEfficientRingAttention.SP_STREAM is None:
            MemEfficientRingAttention.SP_STREAM = torch.cuda.Stream()
        out, softmax_lse = MemEfficientRingAttention.apply(
            q, k, v, sp_group, MemEfficientRingAttention.SP_STREAM, softmax_scale, attn_mask
        )
        if return_softmax:
            return out, softmax_lse
        return out


class MemEfficientRingAttnProcessor:
    """diffusers `Attention` processor that runs ring attention over `sp_group`.

    The hidden states are split along the sequence dim across the group before
    QKV projection and gathered back after the output projection.
    """

    def __init__(self, sp_group: dist.ProcessGroup):
        self.sp_group = sp_group
        if not HAS_XFORMERS:
            raise ImportError("MemEfficientRingAttnProcessor requires xformers, to use it, please install xformers.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        sp_group = self.sp_group
        assert sp_group is not None, "sp_group must be provided for MemEfficientRingAttnProcessor"
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # [B, C, H, W] -> [B, H*W, C] for attention.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Shard the sequence across the SP group before QKV projection.
        hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        assert (
            query.shape[1] % dist.get_world_size(sp_group) == 0
        ), f"sequence length ({query.shape[1]}) must be divisible by sp_group size ({dist.get_world_size(sp_group)})"
        hidden_states = MemEfficientRingAttention.attention(query, key, value, sp_group, attn_mask=attention_mask)

        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the sequence sharding.
        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


class ContextParallelAttention:
    """Shardformer-style wrapper: never instantiated, only `from_native_module`."""

    def __init__(self):
        raise ImportError(f"ContextParallelAttention should not be initialized directly.")

    @staticmethod
    def from_native_module(module: Attention, process_group, *args, **kwargs) -> Attention:
        """Convert a diffusers `Attention` module to use ring attention over
        `process_group`, marking its QKV parameters for sequence-parallel
        gradient aggregation.

        Args:
            module (Attention): the native attention module to be converted.
            process_group: the sequence-parallel process group.

        Returns:
            Attention: the same module, with a ring-attention processor set.
        """
        # Since gradients are computed using only a subset of the data,
        # aggregation of these gradients is necessary during backpropagation.
        # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_q.weight)
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_k.weight)
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_v.weight)
        if module.to_q.bias is not None:
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_q.bias)
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_k.bias)
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_v.bias)
        module.set_processor(MemEfficientRingAttnProcessor(process_group))
        return module


def _context_chunk_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: Optional[torch.Tensor],
    scale: Optional[float],
    seq_align: int = SEQ_ALIGN,
    seq_limit: int = SEQ_LIMIT,
):
    """Attention forward computed in sequence chunks to bound peak memory.

    Q/K/V (and the bias, along both sequence dims) are split into `n_chunks`
    pieces; partial outputs per Q chunk are merged with lse rescaling, exactly
    like the ring-attention merge.

    Returns:
        (out [B, S, N, D], lse [B, N, S], rng_states list of length n_chunks**2)
    """
    seq_len = q.shape[1]
    n_chunks = get_conv3d_n_chunks(seq_len, seq_align, seq_limit)
    q_chunks, k_chunks, v_chunks = q.chunk(n_chunks, dim=1), k.chunk(n_chunks, dim=1), v.chunk(n_chunks, dim=1)
    attn_bias_chunks = attn_bias.chunk(n_chunks, dim=2) if attn_bias is not None else [None] * n_chunks
    out_chunks = []
    lse_chunks = []
    rng_states = []
    for q_chunk, attn_bias_chunk in zip(q_chunks, attn_bias_chunks):
        inner_attn_bias_chunks = (
            attn_bias_chunk.chunk(n_chunks, dim=3) if attn_bias_chunk is not None else [None] * n_chunks
        )
        out_chunk = None
        for k_chunk, v_chunk, inner_attn_bias_chunk in zip(k_chunks, v_chunks, inner_attn_bias_chunks):
            block_out, block_lse, rng_state = _attn_fwd(q_chunk, k_chunk, v_chunk, inner_attn_bias_chunk, scale)
            block_lse = block_lse.transpose(1, 2).unsqueeze(-1).contiguous().float()  # [B, N, S] -> [B, S, N, 1]
            rng_states.append(rng_state)
            if out_chunk is None:
                out_chunk = block_out
                lse_chunk = block_lse
            else:
                out_chunk, lse_chunk = _rescale_out_lse(out_chunk, block_out, lse_chunk, block_lse)
        lse_chunk = lse_chunk.squeeze(-1).transpose(1, 2).contiguous()  # [B, S, N, 1] -> [B, N, S]
        out_chunks.append(out_chunk)
        lse_chunks.append(lse_chunk)
    out = torch.cat(out_chunks, dim=1)
    lse = torch.cat(lse_chunks, dim=-1)
    return out, lse, rng_states


def _context_chunk_attn_bwd(
    grad: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    rng_states: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    seq_align: int = SEQ_ALIGN,
    seq_limit: int = SEQ_LIMIT,
    fast_accum: bool = False,
):
    """Chunked counterpart of `_context_chunk_attn_fwd` for the backward pass.

    Accumulates dq/dk/dv per (q-chunk, kv-chunk) pair directly into views of
    fp32 accumulators (q.dtype when `fast_accum`).  `rng_states` may be a flat
    list of n_chunks**2 states (as produced by the chunked forward), or None —
    dropout is disabled (p=0), so the states do not affect results.
    NOTE(review): when the forward was *not* chunked, callers pass a single
    rng_state; that is only consumed on the n_chunks == 1 fast path here.
    """
    seq_len = q.shape[1]
    n_chunks = get_conv3d_n_chunks(seq_len, seq_align, seq_limit)
    if n_chunks == 1:
        # No chunking needed; single kernel call.
        return _attn_bwd(grad, q, k, v, out, lse, rng_states, attn_bias, scale)
    q_chunks, k_chunks, v_chunks = q.chunk(n_chunks, dim=1), k.chunk(n_chunks, dim=1), v.chunk(n_chunks, dim=1)
    attn_bias_chunks = attn_bias.chunk(n_chunks, dim=2) if attn_bias is not None else [None] * n_chunks
    out_chunks = out.chunk(n_chunks, dim=1)
    dout_chunks = grad.chunk(n_chunks, dim=1)
    lse_chunks = lse.chunk(n_chunks, dim=-1)
    if rng_states is None:
        rng_states = [None] * (n_chunks * n_chunks)
    i = 0
    acc_dtype = q.dtype if fast_accum else torch.float
    dq = torch.zeros_like(q, dtype=acc_dtype)
    dk = torch.zeros_like(k, dtype=acc_dtype)
    dv = torch.zeros_like(v, dtype=acc_dtype)
    # chunk() returns views, so += below accumulates in place into dq/dk/dv.
    dq_chunks = dq.chunk(n_chunks, dim=1)
    dk_chunks = dk.chunk(n_chunks, dim=1)
    dv_chunks = dv.chunk(n_chunks, dim=1)
    for q_idx in range(n_chunks):
        q_chunk = q_chunks[q_idx]
        attn_bias_chunk = attn_bias_chunks[q_idx]
        inner_attn_bias_chunks = (
            attn_bias_chunk.chunk(n_chunks, dim=3) if attn_bias_chunk is not None else [None] * n_chunks
        )
        out_chunk = out_chunks[q_idx]
        dout_chunk = dout_chunks[q_idx]
        lse_chunk = lse_chunks[q_idx]
        dq_acc = dq_chunks[q_idx]
        for kv_idx in range(n_chunks):
            k_chunk = k_chunks[kv_idx]
            v_chunk = v_chunks[kv_idx]
            inner_attn_bias_chunk = inner_attn_bias_chunks[kv_idx]
            dk_acc = dk_chunks[kv_idx]
            dv_acc = dv_chunks[kv_idx]
            block_dq, block_dk, block_dv = _attn_bwd(
                dout_chunk, q_chunk, k_chunk, v_chunk, out_chunk, lse_chunk, rng_states[i], inner_attn_bias_chunk, scale
            )
            dq_acc += block_dq
            dk_acc += block_dk
            dv_acc += block_dv
            i += 1
    return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype)


def prepare_parallel_causal_attention_mask(
    parallel_rank: int, parallel_size: int, n_frame: int, n_hw: int, dtype, device, batch_size: int = None
):
    """Build this rank's slice of a frame-causal attention mask.

    Position i may attend to every position whose frame index is <= i's frame
    (all spatial positions of a frame become visible at once).  The full mask
    has shape [S, S] with S = n_frame * n_hw; each rank materializes only its
    [S/P, S] slice (plus an optional batch dim when `batch_size` is given).
    Masked entries hold the dtype's most-negative value; visible entries are 0.
    """
    seq_len = n_frame * n_hw
    assert seq_len % parallel_size == 0, f"seq_len {seq_len} must be divisible by parallel_size {parallel_size}"
    local_seq_len = seq_len // parallel_size
    local_seq_start = local_seq_len * parallel_rank
    if dtype is torch.bfloat16:
        # A trick to avoid nan of memory efficient attention, maybe introduce some bias
        fmin = torch.finfo(torch.float16).min
    else:
        fmin = torch.finfo(dtype).min
    mask = torch.full((local_seq_len, seq_len), fmin, dtype=dtype, device=device)
    for i in range(local_seq_len):
        i_frame = (i + local_seq_start) // n_hw
        mask[i, : (i_frame + 1) * n_hw] = 0
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask


def prepare_parallel_attention_mask(
    self, hidden_states: torch.Tensor, cp_group: dist.ProcessGroup = None
) -> torch.Tensor:
    """Build the local causal mask for a [B, C, T, H, W] video tensor.

    NOTE(review): takes `self` although defined at module level — presumably
    bound onto a model class elsewhere; verify against the caller.
    """
    B, C, T, H, W = hidden_states.shape
    attention_mask = prepare_parallel_causal_attention_mask(
        dist.get_rank(cp_group),
        dist.get_world_size(cp_group),
        T,
        H * W,
        hidden_states.dtype,
        hidden_states.device,
        batch_size=B,
    )
    return attention_mask


class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
    """Tensor-parallel variant of `UpsampleCausal3D`.

    Replaces the inner conv with a row-parallel `Conv3dTPRow` and splits the
    input channels across `tp_group` in `forward`.
    """

    def __init__(
        self,
        channels,
        out_channels=None,
        kernel_size=3,
        bias=True,
        upsample_factor=(2, 2, 2),
        tp_group=None,
        split_input: bool = False,
        split_output: bool = False,
        conv_=None,
        shortcut_=None,
    ):
        assert tp_group is not None, "tp_group must be provided"
        super().__init__(channels, out_channels, kernel_size, bias, upsample_factor)
        # Reuse the native conv's weights when given; otherwise wrap our own.
        # NOTE(review): `shortcut_` is currently accepted but unused.
        conv = conv_ if conv_ is not None else self.conv.conv
        self.conv.conv = Conv3dTPRow.from_native_module(
            conv, tp_group, split_input=split_input, split_output=split_output
        )
        self.tp_group = tp_group
        tp_size = dist.get_world_size(group=self.tp_group)
        assert self.channels % tp_size == 0, f"channels {self.channels} must be divisible by tp_size {tp_size}"
        # Each rank only sees its channel shard.
        self.channels = self.channels // tp_size

    def forward(self, input_tensor):
        # Split channels (dim=1) across the TP group, then run the native path.
        input_tensor = split_forward_gather_backward(input_tensor, 1, self.tp_group)
        return super().forward(input_tensor)

    @staticmethod
    def from_native_module(module: UpsampleCausal3D, process_group, **kwargs):
        """Wrap a native `UpsampleCausal3D` into its tensor-parallel variant."""
        conv = module.conv.conv
        return TPUpDecoderBlockCausal3D(
            module.channels,
            module.out_channels,
            conv.kernel_size[0],
            conv.bias is not None,
            module.upsample_factor,
            conv_=conv,
            shortcut_=getattr(module, "shortcut", None),
            tp_group=process_group,
            **kwargs,
        )