# mysora/opensora/models/hunyuan_vae/distributed.py
# (capture metadata: 581 lines, 22 KiB, Python)

from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.layer.attn import RingComm, _rescale_out_lse
from colossalai.shardformer.layer.utils import SeqParallelUtils
from diffusers.models.attention_processor import Attention
from opensora.models.vae.tensor_parallel import Conv3dTPRow
from opensora.models.vae.utils import get_conv3d_n_chunks
from .unet_causal_3d_blocks import UpsampleCausal3D
# xformers is an optional dependency: the ring/chunked attention paths below use
# its private fmha entry points directly. Degrade gracefully when it is absent;
# MemEfficientRingAttnProcessor raises at construction time if it is needed.
try:
    from xformers.ops.fmha import (
        Context,
        Inputs,
        _memory_efficient_attention_backward,
        _memory_efficient_attention_forward_requires_grad,
    )
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False

# Chunking knobs for the context-chunked attention: chunk boundaries are aligned
# to SEQ_ALIGN tokens and each chunk is capped at SEQ_LIMIT tokens.
SEQ_ALIGN = 32
SEQ_LIMIT = 16 * 1024
def align_atten_bias(attn_bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return ``attn_bias`` backed by storage whose last dim is 8-aligned.

    xformers' memory-efficient attention kernels require the bias' innermost
    stride/extent to be aligned; we allocate a padded buffer and copy the bias
    into a view of it, so the returned tensor keeps the original shape
    ``[B, N, S_q, S_kv]`` but has an aligned row stride.

    Args:
        attn_bias: bias of shape ``[B, N, S_q, S_kv]``, or None.

    Returns:
        The (possibly re-allocated) bias, or None when ``attn_bias`` is None.
    """
    # Callers (_attn_fwd/_attn_bwd) accept attn_bias=None; nothing to align then.
    if attn_bias is None:
        return None
    # NOTE: the original unpacked `B, N, S, S`, which built a [B, N, S_kv, pad]
    # buffer and broke for non-square biases (e.g. ring-attention blocks of
    # shape [B, N, S/P, S_chunk]); unpack the two sequence dims separately.
    B, N, S_q, S_kv = attn_bias.shape
    align_size = 8
    if S_kv % align_size != 0:
        padded_S = (S_kv // align_size + 1) * align_size
        buf = torch.empty((B, N, S_q, padded_S), dtype=attn_bias.dtype, device=attn_bias.device)
        attn_bias = buf[:, :, :, :S_kv].copy_(attn_bias)
    return attn_bias
def _attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """Run one xformers memory-efficient attention forward block.

    The bias is re-allocated with an 8-aligned last dim first; the returned
    log-sum-exp is trimmed back to the query length if the kernel padded it.

    Returns:
        (out, lse, rng_state) — attention output, per-query log-sum-exp, and the
        kernel's rng state (needed to replay dropout in backward).
    """
    aligned_bias = align_atten_bias(attn_bias)
    inputs = Inputs(q, k, v, aligned_bias, p=0, scale=scale, is_partial=False)
    out, fwd_ctx = _memory_efficient_attention_forward_requires_grad(inputs, None)
    # The kernel may pad the LSE's trailing dim; trim it to the true query length.
    q_len = aligned_bias.shape[-2]
    lse = fwd_ctx.lse
    if lse.shape[-1] != q_len:
        lse = lse[:, :, :q_len]
    return out, lse, fwd_ctx.rng_state
def _attn_bwd(
    grad: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    rng_state: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """Run one xformers memory-efficient attention backward block.

    Reconstructs the forward Context from the saved (lse, out, rng_state) and
    returns the gradients ``(dq, dk, dv)``.
    """
    aligned_bias = align_atten_bias(attn_bias)
    inputs = Inputs(q, k, v, aligned_bias, p=0, scale=scale, output_dtype=q.dtype, is_partial=False)
    bwd_ctx = Context(lse, out, rng_state=rng_state)
    grads = _memory_efficient_attention_backward(bwd_ctx, inputs, grad, None)
    return grads.dq, grads.dk, grads.dv
class MemEfficientRingAttention(torch.autograd.Function):
    """Ring attention built on xformers memory-efficient attention kernels.

    Each rank holds one shard of the sequence. K/V blocks rotate around the
    sequence-parallel group via send/recv while local attention blocks are
    computed and merged online with log-sum-exp rescaling, double-buffered
    across two CUDA streams to overlap communication with compute.
    """

    # CUDA event recorded after each local attention kernel; queried before a
    # comm buffer is reused so in-flight kernel inputs are not clobbered.
    ATTN_DONE: torch.cuda.Event = None
    # Dedicated stream used to overlap output/LSE correction with the next round.
    SP_STREAM: torch.cuda.Stream = None

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        softmax_scale: Optional[float] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ring attention forward

        Args:
            ctx (_type_): self
            q (torch.Tensor): shape [B, S/P, N, D]
            k (torch.Tensor): shape [B, S/P, N, D]
            v (torch.Tensor): shape [B, S/P, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            sp_stream (torch.cuda.Stream): sequence parallel stream
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            attn_mask (Optional[torch.Tensor], optional): attention mask shape [B, N, S/P, S]. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output and log sum exp. Output's shape should be [B, S/P, N, D].
            LSE's shape should be [B, N, S/P].
        """
        if softmax_scale is None:
            # Standard 1/sqrt(head_dim) scaling.
            softmax_scale = q.shape[-1] ** (-0.5)
        sp_size = dist.get_world_size(sp_group)
        sp_rank = dist.get_rank(sp_group)
        # Two communicators so consecutive rounds can ping-pong the double buffers.
        kv_comms: List[RingComm] = [RingComm(sp_group) for _ in range(2)]
        block_attn_masks = [None] * sp_size
        if attn_mask is not None:
            # if attn_mask is splitted, uncomment the following line
            # attn_mask = attn_mask.chunk(sp_size, dim=2)[sp_rank]
            # Split the local [B, N, S/P, S] mask along the KV dim: one block per
            # source rank's KV shard.
            block_attn_masks = attn_mask.chunk(sp_size, dim=-1)
        # [B, S, N, D]
        q, k, v = [x.contiguous() for x in [q, k, v]]
        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))
        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        def _kv_comm(i):
            # Avoid overwriting attn input when it shares mem with buffer
            if not MemEfficientRingAttention.ATTN_DONE.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < sp_size - 1:
                kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

        # KV shards arrive in decreasing-rank order, starting from our own shard.
        block_idx = sp_rank
        for i in range(sp_size):
            with torch.cuda.stream(sp_streams[i % 2]):
                # Wait for current kv from prev rank
                # NOTE: waiting outside the current stream will NOT correctly synchronize.
                if i == 0:
                    _kv_comm(i)
                else:
                    kv_comms[(i + 1) % 2].wait()
                kv_block = kv_buffers[i % 2]
                q_block = q
                block_out[i % 2], block_softmax_lse[i % 2], rng_states[i] = _attn_fwd(
                    q_block, kv_block[0], kv_block[1], attn_bias=block_attn_masks[block_idx], scale=softmax_scale
                )
                MemEfficientRingAttention.ATTN_DONE.record()
                # Pipeline the next KV comm with output correction instead of the next flash attn
                # to minimize idle time when comm takes longer than attn.
                _kv_comm(i + 1)
                block_softmax_lse[i % 2] = (
                    block_softmax_lse[i % 2].transpose(1, 2).unsqueeze(-1).contiguous().float()
                )  # [B, N, S] -> [B, S, N, 1]
                assert (
                    block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                ), f"{block_out[i % 2].shape} != {block_softmax_lse[i % 2].shape}"
                # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
                # In reality this always finishes before next flash attn; no need for extra sync.
                if i == 0:
                    out = block_out[0]
                    softmax_lse = block_softmax_lse[0]
                else:
                    out, softmax_lse = _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2])
            block_idx = (block_idx - 1) % sp_size
        # Ensure all work issued on the side stream is visible to the caller's stream.
        torch.cuda.current_stream().wait_stream(sp_stream)
        out = out.to(q.dtype)
        softmax_lse = softmax_lse.squeeze(-1).transpose(1, 2).contiguous()  # [B, S, N, 1] -> [B, N, S]
        ctx.softmax_scale = softmax_scale
        ctx.block_attn_masks = block_attn_masks
        ctx.sp_group = sp_group
        ctx.save_for_backward(q, k, v, out, softmax_lse, *rng_states)  # lse [B, N, S]
        return out, softmax_lse

    @staticmethod
    def backward(ctx, grad_output, grad_softmax_lse):
        """Ring attention backward.

        ``grad_softmax_lse`` (gradient w.r.t. the returned LSE) is unused. dq is
        accumulated locally; partial dk/dv sums travel around the ring so each
        rank ends up with the gradient for its own KV shard.
        """
        # q, k, v, out: [B, S, N, D], softmax_lse: [B, N, S]
        q, k, v, out, softmax_lse, *rng_states = ctx.saved_tensors
        sp_group = ctx.sp_group
        sp_size = dist.get_world_size(sp_group)
        kv_comm = RingComm(sp_group)
        dkv_comm = RingComm(sp_group)
        grad_output = grad_output.contiguous()
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))
        dq = None
        # Accumulate dk/dv in float32 for numerical stability across ring steps.
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float) for kv in kv_buffers]
        del k, v
        # Same KV rotation order as forward: our shard first, then previous ranks'.
        block_idx = dist.get_rank(sp_group)
        for i in range(sp_size):
            if i > 0:
                kv_comm.wait()
            if i < sp_size - 1:
                kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
            k_block, v_block = kv_buffers[i % 2]
            # NOTE(review): forward saves one rng state per ring step, but
            # _context_chunk_attn_bwd indexes rng_states per (q, kv) chunk pair
            # when it re-chunks (n_chunks > 1) — confirm the chunk counts match
            # the forward pass for dropout replay.
            dq_block, dk_block, dv_block = _context_chunk_attn_bwd(
                grad_output,
                q,
                k_block,
                v_block,
                out,
                softmax_lse,
                rng_states[i],
                attn_bias=ctx.block_attn_masks[block_idx],
                scale=ctx.softmax_scale,
            )
            if i == 0:
                dq = dq_block.float()
                dkv_buffers[i % 2][0] = dk_block.float()
                dkv_buffers[i % 2][1] = dv_block.float()
            else:
                dq += dq_block
                # Wait for the partial dkv arriving from the previous rank before
                # folding in this step's contribution.
                dkv_comm.wait()
                dkv_buffers[i % 2][0] += dk_block
                dkv_buffers[i % 2][1] += dv_block
            # Forward the running dkv partial sums to the next rank in the ring.
            dkv_comm.send_recv(dkv_buffers[i % 2], dkv_buffers[(i + 1) % 2])
            block_idx = (block_idx - 1) % sp_size
        dkv_comm.wait()
        # After sp_size hops, the fully accumulated dkv for our own shard sits in
        # the buffer last received into: index (sp_size - 1 + 1) % 2.
        dkv = dkv_buffers[sp_size % 2]
        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv)]
        torch.cuda.empty_cache()
        # Gradients for (q, k, v); the remaining forward inputs are
        # non-differentiable. Extra trailing Nones are tolerated by autograd.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None

    @staticmethod
    def attention(
        q,
        k,
        v,
        sp_group,
        softmax_scale: Optional[float] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_softmax: bool = False,
    ):
        """Ring attention

        Args:
            q (torch.Tensor): shape [B, S, N, D]
            k (torch.Tensor): shape [B, S, N, D]
            v (torch.Tensor): shape [B, S, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            attn_mask (Optional[torch.Tensor], optional): attention mask. Defaults to None.
            return_softmax (bool, optional): return softmax or not. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output and log sum exp. Output's shape should be [B, S, N, D].
            LSE's shape should be [B, N, S].
        """
        # Lazily create the class-level event/stream shared by all instances.
        if MemEfficientRingAttention.ATTN_DONE is None:
            MemEfficientRingAttention.ATTN_DONE = torch.cuda.Event()
        if MemEfficientRingAttention.SP_STREAM is None:
            MemEfficientRingAttention.SP_STREAM = torch.cuda.Stream()
        out, softmax_lse = MemEfficientRingAttention.apply(
            q, k, v, sp_group, MemEfficientRingAttention.SP_STREAM, softmax_scale, attn_mask
        )
        if return_softmax:
            return out, softmax_lse
        return out
class MemEfficientRingAttnProcessor:
    """diffusers ``Attention`` processor that shards the sequence across a
    sequence-parallel group and computes attention with
    :class:`MemEfficientRingAttention`."""

    def __init__(self, sp_group: dist.ProcessGroup):
        # Sequence-parallel group used for the ring attention and for the
        # split/gather of hidden states in __call__.
        self.sp_group = sp_group
        if not HAS_XFORMERS:
            raise ImportError("MemEfficientRingAttnProcessor requires xformers, to use it, please install xformers.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Standard diffusers processor protocol.

        Accepts 3D ([B, S, C]) or 4D ([B, C, H, W]) hidden states; 4D input is
        flattened to a sequence for attention and restored before returning.
        """
        sp_group = self.sp_group
        assert sp_group is not None, "sp_group must be provided for MemEfficientRingAttnProcessor"
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # [B, C, H, W] -> [B, H*W, C]
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        # Shard the sequence dim across the SP group; each rank projects q/k/v
        # for its local shard only (gradients are gathered on backward).
        hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # [B, S/P, heads, head_dim] layout expected by MemEfficientRingAttention.
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)
        # NOTE(review): query.shape[1] here is already the local shard length
        # (the split happened above), so this asserts S/P % P == 0 — confirm
        # that is the intended divisibility requirement.
        assert (
            query.shape[1] % dist.get_world_size(sp_group) == 0
        ), f"sequence length ({query.shape[1]}) must be divisible by sp_group size ({dist.get_world_size(sp_group)})"
        hidden_states = MemEfficientRingAttention.attention(query, key, value, sp_group, attn_mask=attention_mask)
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # Re-assemble the full sequence from all ranks.
        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
        if input_ndim == 4:
            # Restore the original [B, C, H, W] layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class ContextParallelAttention:
    """Converter that equips a diffusers ``Attention`` module with ring-attention
    sequence parallelism.

    This class is a pure namespace for :meth:`from_native_module` and is never
    instantiated.
    """

    def __init__(self):
        # Kept as ImportError for backward compatibility with existing callers;
        # this converter must not be instantiated directly.
        raise ImportError("ContextParallelAttention should not be initialized directly.")

    @staticmethod
    def from_native_module(module: Attention, process_group, *args, **kwargs) -> Attention:
        """Patch a native diffusers ``Attention`` module for sequence parallelism.

        Args:
            module (Attention): the native attention module to convert in place.
            process_group: sequence-parallel process group handed to the ring
                attention processor.

        Returns:
            Attention: the same module, with its q/k/v projection parameters
            marked for sequence-parallel gradient aggregation and its processor
            replaced by :class:`MemEfficientRingAttnProcessor`.
        """
        # Since gradients are computed using only a subset of the data,
        # aggregation of these gradients is necessary during backpropagation.
        # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_q.weight)
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_k.weight)
        SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_v.weight)
        if module.to_q.bias is not None:
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_q.bias)
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_k.bias)
            SeqParallelUtils.marked_as_sp_partial_derived_param(module.to_v.bias)
        module.set_processor(MemEfficientRingAttnProcessor(process_group))
        return module
def _context_chunk_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_bias: Optional[torch.Tensor],
    scale: Optional[float],
    seq_align: int = SEQ_ALIGN,
    seq_limit: int = SEQ_LIMIT,
):
    """Tiled attention forward over sequence chunks to bound peak memory.

    q/k/v are [B, S, N, D]; attn_bias (if given) is tiled on both of its
    sequence dims. For each query chunk, KV chunks are folded in one at a time
    using the same online log-sum-exp merge as ring attention.

    Returns:
        out: [B, S, N, D] attention output.
        lse: [B, N, S] per-query log-sum-exp.
        rng_states: one entry per processed block, row-major over
            (q_chunk, kv_chunk).
    """
    seq_len = q.shape[1]
    # Chunk count chosen so each chunk is seq_align-aligned and <= seq_limit.
    n_chunks = get_conv3d_n_chunks(seq_len, seq_align, seq_limit)
    q_chunks, k_chunks, v_chunks = q.chunk(n_chunks, dim=1), k.chunk(n_chunks, dim=1), v.chunk(n_chunks, dim=1)
    # Bias dim 2 follows the query chunks; dim 3 (below) follows the KV chunks.
    attn_bias_chunks = attn_bias.chunk(n_chunks, dim=2) if attn_bias is not None else [None] * n_chunks
    out_chunks = []
    lse_chunks = []
    rng_states = []
    for q_chunk, attn_bias_chunk in zip(q_chunks, attn_bias_chunks):
        inner_attn_bias_chunks = (
            attn_bias_chunk.chunk(n_chunks, dim=3) if attn_bias_chunk is not None else [None] * n_chunks
        )
        out_chunk = None
        for k_chunk, v_chunk, inner_attn_bias_chunk in zip(k_chunks, v_chunks, inner_attn_bias_chunks):
            block_out, block_lse, rng_state = _attn_fwd(q_chunk, k_chunk, v_chunk, inner_attn_bias_chunk, scale)
            block_lse = block_lse.transpose(1, 2).unsqueeze(-1).contiguous().float()  # [B, N, S] -> [B, S, N, 1]
            rng_states.append(rng_state)
            if out_chunk is None:
                # First KV block for this query chunk: take it as-is.
                out_chunk = block_out
                lse_chunk = block_lse
            else:
                # Merge the new block into the running output via LSE rescaling.
                out_chunk, lse_chunk = _rescale_out_lse(out_chunk, block_out, lse_chunk, block_lse)
        lse_chunk = lse_chunk.squeeze(-1).transpose(1, 2).contiguous()  # [B, S, N, 1] -> [B, N, S]
        out_chunks.append(out_chunk)
        lse_chunks.append(lse_chunk)
    # Stitch the per-query-chunk results back along the sequence dims.
    out = torch.cat(out_chunks, dim=1)
    lse = torch.cat(lse_chunks, dim=-1)
    return out, lse, rng_states
def _context_chunk_attn_bwd(
    grad: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    rng_states: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    seq_align: int = SEQ_ALIGN,
    seq_limit: int = SEQ_LIMIT,
    fast_accum: bool = False,
):
    """Tiled attention backward mirroring ``_context_chunk_attn_fwd``'s chunking.

    Args mirror ``_attn_bwd``, plus:
        rng_states: per-(q_chunk, kv_chunk) rng states, row-major, or None.
            NOTE(review): annotated as a tensor but indexed like a sequence when
            n_chunks > 1 — confirm the caller's type.
        fast_accum: accumulate gradients in the input dtype instead of float32
            (faster, less precise).

    Returns:
        (dq, dk, dv) cast back to the dtypes of q/k/v.
    """
    seq_len = q.shape[1]
    n_chunks = get_conv3d_n_chunks(seq_len, seq_align, seq_limit)
    if n_chunks == 1:
        # No tiling needed; rng_states holds the single block's state.
        return _attn_bwd(grad, q, k, v, out, lse, rng_states, attn_bias, scale)
    q_chunks, k_chunks, v_chunks = q.chunk(n_chunks, dim=1), k.chunk(n_chunks, dim=1), v.chunk(n_chunks, dim=1)
    attn_bias_chunks = attn_bias.chunk(n_chunks, dim=2) if attn_bias is not None else [None] * n_chunks
    out_chunks = out.chunk(n_chunks, dim=1)
    dout_chunks = grad.chunk(n_chunks, dim=1)
    lse_chunks = lse.chunk(n_chunks, dim=-1)
    if rng_states is None:
        rng_states = [None] * (n_chunks * n_chunks)
    i = 0  # flat (q_idx, kv_idx) index into rng_states
    acc_dtype = q.dtype if fast_accum else torch.float
    dq = torch.zeros_like(q, dtype=acc_dtype)
    dk = torch.zeros_like(k, dtype=acc_dtype)
    dv = torch.zeros_like(v, dtype=acc_dtype)
    # chunk() returns views, so in-place accumulation into the chunks below
    # updates dq/dk/dv directly.
    dq_chunks = dq.chunk(n_chunks, dim=1)
    dk_chunks = dk.chunk(n_chunks, dim=1)
    dv_chunks = dv.chunk(n_chunks, dim=1)
    for q_idx in range(n_chunks):
        q_chunk = q_chunks[q_idx]
        attn_bias_chunk = attn_bias_chunks[q_idx]
        inner_attn_bias_chunks = (
            attn_bias_chunk.chunk(n_chunks, dim=3) if attn_bias_chunk is not None else [None] * n_chunks
        )
        out_chunk = out_chunks[q_idx]
        dout_chunk = dout_chunks[q_idx]
        lse_chunk = lse_chunks[q_idx]
        dq_acc = dq_chunks[q_idx]
        for kv_idx in range(n_chunks):
            k_chunk = k_chunks[kv_idx]
            v_chunk = v_chunks[kv_idx]
            inner_attn_bias_chunk = inner_attn_bias_chunks[kv_idx]
            dk_acc = dk_chunks[kv_idx]
            dv_acc = dv_chunks[kv_idx]
            block_dq, block_dk, block_dv = _attn_bwd(
                dout_chunk, q_chunk, k_chunk, v_chunk, out_chunk, lse_chunk, rng_states[i], inner_attn_bias_chunk, scale
            )
            # Accumulate into the view-backed chunk buffers.
            dq_acc += block_dq
            dk_acc += block_dk
            dv_acc += block_dv
            i += 1
    return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype)
def prepare_parallel_causal_attention_mask(
    parallel_rank: int, parallel_size: int, n_frame: int, n_hw: int, dtype, device, batch_size: int = None
):
    """Build this rank's shard of a frame-causal additive attention mask.

    The full sequence is ``n_frame * n_hw`` tokens (``n_hw`` spatial tokens per
    frame). A query token may attend to all tokens of its own frame and every
    earlier frame; later frames receive the dtype's most negative value. Each
    rank only materializes its ``seq_len / parallel_size`` query rows.

    Args:
        parallel_rank: rank of this process in the sequence-parallel group.
        parallel_size: number of sequence-parallel ranks.
        n_frame: number of temporal frames.
        n_hw: spatial tokens per frame (H * W).
        dtype: dtype of the returned mask.
        device: device of the returned mask.
        batch_size: if given, the mask is expanded (zero-copy view) to
            ``[batch_size, local_seq_len, seq_len]``.

    Returns:
        torch.Tensor: additive mask, 0 where attention is allowed and the
        negative fill value elsewhere.
    """
    seq_len = n_frame * n_hw
    assert seq_len % parallel_size == 0, f"seq_len {seq_len} must be divisible by parallel_size {parallel_size}"
    local_seq_len = seq_len // parallel_size
    local_seq_start = local_seq_len * parallel_rank
    if dtype is torch.bfloat16:
        # A trick to avoid nan of memory efficient attention, maybe introduce some bias
        fmin = torch.finfo(torch.float16).min
    else:
        fmin = torch.finfo(dtype).min
    # Vectorized construction (replaces a Python loop over local rows): local row
    # i belongs to frame (i + local_seq_start) // n_hw and may see every token up
    # to the end of that frame.
    row_frame = (torch.arange(local_seq_len, device=device) + local_seq_start) // n_hw
    visible = torch.arange(seq_len, device=device).unsqueeze(0) < ((row_frame + 1) * n_hw).unsqueeze(1)
    mask = torch.full((local_seq_len, seq_len), fmin, dtype=dtype, device=device)
    mask.masked_fill_(visible, 0)
    if batch_size is not None:
        # expand() broadcasts without copying; all batch entries share storage.
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
def prepare_parallel_attention_mask(
    self, hidden_states: torch.Tensor, cp_group: dist.ProcessGroup = None
) -> torch.Tensor:
    """Build this rank's frame-causal attention mask for a video latent batch.

    ``hidden_states`` is laid out as [B, C, T, H, W]; the mask covers the
    flattened T*(H*W) token sequence, sharded across ``cp_group``. ``self`` is
    accepted (but unused) so this can be bound as a method.
    """
    batch, _, frames, height, width = hidden_states.shape
    return prepare_parallel_causal_attention_mask(
        dist.get_rank(cp_group),
        dist.get_world_size(cp_group),
        frames,
        height * width,
        hidden_states.dtype,
        hidden_states.device,
        batch_size=batch,
    )
class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
    """Tensor-parallel variant of ``UpsampleCausal3D``.

    The inner 3D convolution is replaced with a row-parallel ``Conv3dTPRow``;
    ``forward`` splits the channel dim across ``tp_group`` so each rank convolves
    only its channel shard.
    """

    def __init__(
        self,
        channels,
        out_channels=None,
        kernel_size=3,
        bias=True,
        upsample_factor=(2, 2, 2),
        tp_group=None,
        split_input: bool = False,
        split_output: bool = False,
        conv_=None,
        shortcut_=None,
    ):
        assert tp_group is not None, "tp_group must be provided"
        super().__init__(channels, out_channels, kernel_size, bias, upsample_factor)
        # Reuse the pretrained conv from the native module when provided so its
        # weights are preserved; otherwise shard the freshly created one.
        conv = conv_ if conv_ is not None else self.conv.conv
        self.conv.conv = Conv3dTPRow.from_native_module(
            conv, tp_group, split_input=split_input, split_output=split_output
        )
        self.tp_group = tp_group
        tp_size = dist.get_world_size(group=self.tp_group)
        assert self.channels % tp_size == 0, f"channels {self.channels} must be divisible by tp_size {tp_size}"
        # Each rank only sees its channel shard after the split in forward().
        self.channels = self.channels // tp_size
        # NOTE(review): shortcut_ is accepted but currently unused — confirm
        # whether the native module's shortcut needs to be carried over.

    def forward(self, input_tensor):
        # Scatter channels across the TP group; gradients are gathered on backward.
        input_tensor = split_forward_gather_backward(input_tensor, 1, self.tp_group)
        return super().forward(input_tensor)

    # Fixed: this was a plain function in the class body; without @staticmethod,
    # calling it on an instance would bind the instance as `module`.
    @staticmethod
    def from_native_module(module: UpsampleCausal3D, process_group, **kwargs):
        """Build a tensor-parallel block from a native ``UpsampleCausal3D``,
        reusing its convolution weights."""
        conv = module.conv.conv
        return TPUpDecoderBlockCausal3D(
            module.channels,
            module.out_channels,
            conv.kernel_size[0],
            conv.bias is not None,
            module.upsample_factor,
            conv_=conv,
            shortcut_=getattr(module, "shortcut", None),
            tp_group=process_group,
            **kwargs,
        )