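# Tests for block attention (kernel_size / shift_window) and the
# split_seq_cat_batch / split_batch_cat_seq transforms in opensora.models.layers.blocks.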
from itertools import product

import pytest
import torch
from colossalai.accelerator import get_accelerator
from colossalai.utils import get_current_device

from opensora.models.layers.blocks import Attention, split_batch_cat_seq, split_seq_cat_batch
from opensora.models.layers.rotary_embedding_torch import RotaryEmbedding

# B, S, H = 7488, 1, 1152
# B, S, H = 32, 234, 1152
B, S, H = 128, 32, 1152
N, D = 16, 72


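# Standalone benchmark (not collected by pytest): runs one forward + backward pass of
# Attention and reports peak accelerator memory, with or without flash attention.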
def run_attn(enable_flash_attn: bool):
    get_accelerator().reset_peak_memory_stats()
    rope = RotaryEmbedding(D).to(device=get_current_device(), dtype=torch.bfloat16)
    attn = Attention(
        H,
        N,
        qkv_bias=True,
        rope=rope.rotate_queries_or_keys,
        enable_flash_attn=enable_flash_attn,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    x = torch.randn(B, S, H, device=get_current_device(), dtype=torch.bfloat16).requires_grad_()
    y = attn(x)
    y.mean().backward()
    print(f"Peak memory: {get_accelerator().max_memory_allocated() / 1024**2:.2f} MB")


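# split_seq_cat_batch followed by split_batch_cat_seq should reconstruct the input exactly.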
def test_block_transform():
    b, h, w, c = 8, 12, 4, 3
    x = torch.randn(b, h, w, c)
    kernel_sizes = (3, 2)
    dims = (1, 2)
    num_splits = [x.size(d) // k for d, k in zip(dims, kernel_sizes)]
    y = split_seq_cat_batch(x, kernel_sizes, dims)
    z = split_batch_cat_seq(y, b, num_splits, dims)
    assert torch.equal(x, z)


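# Block attention with flash attention enabled should run forward and backward on both
# divisible and undivisible shapes, with and without window shifting.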
@pytest.mark.parametrize(
    "shape, kernel_sizes",
    [
        [(8, 12, 4, 1), (2, 2, -1)],  # divisible + N<B + 3D
        [(1, 5, 2, 5), (2, 2, -1)],  # undivisible + N>B + 3D
    ],
)
@pytest.mark.parametrize("shift_window", [False, True])
def test_block_attn_nd(shape, kernel_sizes, shift_window):
    hidden_size = 96
    num_heads = 4
    head_dim = hidden_size // num_heads
    rope = RotaryEmbedding(head_dim // 3).to(device=get_current_device(), dtype=torch.bfloat16)
    attn = Attention(
        hidden_size,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        enable_flash_attn=True,
        rope=rope.rotate_queries_or_keys,
        kernel_size=kernel_sizes,
        shift_window=shift_window,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    # [B, H, W, C]
    x = torch.rand(*shape, hidden_size, device=get_current_device(), dtype=torch.bfloat16).requires_grad_()
    y = attn(x)
    assert x.shape == y.shape
    loss = y.mean()
    loss.backward()


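# With non-overlapping blocks, attention over the full tensor must equal attention
# applied independently to each kernel-sized block.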
@pytest.mark.parametrize(
    "shape, kernel_sizes",
    [
        [(8, 12, 4, 1), (2, 2, -1)],  # divisible + N<B + 3D
        [(8, 12, 4, 6), (2, 2, -1)],  # divisible + N<B + 3D
        [(8, 12, 3, 6), (2, 2, -1)],  # undivisible + N<B + 3D
        [(8, 90, 60, 13), (8, 8, -1)],  # 480p video
        [(8, 160, 90, 1), (8, 8, -1)],  # 720p image
    ],
)
def test_block_attn_3d(shape, kernel_sizes):
    hidden_size = 96
    num_heads = 4
    head_dim = hidden_size // num_heads
    rope = RotaryEmbedding(head_dim // 3).to(device=get_current_device(), dtype=torch.bfloat16)
    attn = Attention(
        hidden_size,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        enable_flash_attn=False,
        rope=rope.rotate_queries_or_keys,
        kernel_size=kernel_sizes,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    # [B, H, W, T, C]
    x = torch.rand(*shape, hidden_size, device=get_current_device(), dtype=torch.bfloat16)
    y = attn(x)

    split_size = [k if k > 0 else x.size(i + 1) for i, k in enumerate(kernel_sizes)]
    for start_indices in product(*[range(0, x.size(i + 1), s) for i, s in enumerate(split_size)]):
        piece = x[
            :,
            start_indices[0] : start_indices[0] + split_size[0],
            start_indices[1] : start_indices[1] + split_size[1],
            start_indices[2] : start_indices[2] + split_size[2],
            :,
        ]
        piece_z = attn(piece)
        piece_y = y[
            :,
            start_indices[0] : start_indices[0] + split_size[0],
            start_indices[1] : start_indices[1] + split_size[1],
            start_indices[2] : start_indices[2] + split_size[2],
            :,
        ]
        assert piece_y.shape == piece_z.shape
        assert torch.equal(
            piece_z,
            piece_y,
        )


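# Two Attention modules sharing the same weights but configured with different
# kernel sizes should produce matching outputs on these shapes.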
@pytest.mark.parametrize(
    "shape, kernel_sizes, kernel_sizes2",
    [
        [(2, 80, 60, 1), (8, 8, 4), (8, 8, -1)],  # 720p image
        [(2, 4, 4, 6), (4, 4, -1), (8, 8, -1)],  # small video, kernel covers the whole spatial extent
    ],
)
def test_block_attn_3d_var_kernel(shape, kernel_sizes, kernel_sizes2):
    hidden_size = 24
    num_heads = 2
    head_dim = hidden_size // num_heads
    rope = RotaryEmbedding(head_dim // 3).to(device=get_current_device(), dtype=torch.bfloat16)
    attn = Attention(
        hidden_size,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        enable_flash_attn=False,
        rope=rope.rotate_queries_or_keys,
        kernel_size=kernel_sizes,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    # [B, H, W, T, C]
    x = torch.rand(*shape, hidden_size, device=get_current_device(), dtype=torch.bfloat16)
    y = attn(x)
    attn2 = Attention(
        hidden_size,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        enable_flash_attn=False,
        rope=rope.rotate_queries_or_keys,
        kernel_size=kernel_sizes2,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    attn2.load_state_dict(attn.state_dict())
    y2 = attn2(x)
    assert y.shape == y2.shape
    torch.testing.assert_close(y, y2)


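# Block locality: the output on the first 40x40 patch of a larger input should match
# the output of attention on the 40x40 input alone, since 40 is a multiple of the
# kernel size and the extra region falls into separate blocks.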
def test_block_attn_3d_overlap():
    kernel_sizes = (8, 8, -1)
    hidden_size = 24
    num_heads = 2
    head_dim = hidden_size // num_heads
    rope = RotaryEmbedding(head_dim // 3).to(device=get_current_device(), dtype=torch.bfloat16)
    attn = Attention(
        hidden_size,
        num_heads,
        qkv_bias=True,
        qk_norm=True,
        enable_flash_attn=False,
        rope=rope.rotate_queries_or_keys,
        kernel_size=kernel_sizes,
    ).to(device=get_current_device(), dtype=torch.bfloat16)
    # [B, H, W, T, C]
    x = torch.rand(2, 40, 40, 6, hidden_size, device=get_current_device(), dtype=torch.bfloat16)
    y = attn(x)
    x2 = torch.rand(2, 48, 48, 6, hidden_size, device=get_current_device(), dtype=torch.bfloat16)
    x2[:, :40, :40] = x
    y2 = attn(x2)
    torch.testing.assert_close(y, y2[:, :40, :40])


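# 3D round-trip: each batch chunk produced by split_seq_cat_batch must equal the
# corresponding block slice of the input, and split_batch_cat_seq must restore it.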
def test_block_transform_3d():
    b, h, w, t, c = 8, 12, 4, 6, 3
    x = torch.randn(b, h, w, t, c)
    kernel_sizes = (3, 2, 6)
    dims = (1, 2, 3)
    num_splits = [x.size(d) // k for d, k in zip(dims, kernel_sizes)]
    y = split_seq_cat_batch(x, kernel_sizes, dims)
    split_size = [k if k > 0 else x.size(i + 1) for i, k in enumerate(kernel_sizes)]
    for i, start_indices in enumerate(product(*[range(0, x.size(i + 1), s) for i, s in enumerate(split_size)])):
        piece = x[
            :,
            start_indices[0] : start_indices[0] + split_size[0],
            start_indices[1] : start_indices[1] + split_size[1],
            start_indices[2] : start_indices[2] + split_size[2],
            :,
        ]
        y_piece = y[i * b : (i + 1) * b]
        assert torch.equal(y_piece, piece), f"{y_piece.shape} vs {piece.shape}"

    z = split_batch_cat_seq(y, b, num_splits, dims)
    assert torch.equal(x, z)


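# Running this file directly benchmarks peak memory with and without flash attention.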
if __name__ == "__main__":
    print("Use flashattn")
    run_attn(True)
    print("No flashattn")
    run_attn(False)