# mysora/tests/test_tiled_conv3d.py
#
# 53 lines
# 1.9 KiB
# Python

import itertools

import torch
from opensora.models.layers.tiled_conv3d import TiledConv3d
# Module-wide cuDNN configuration: use only deterministic convolution
# algorithms, and let cuDNN benchmark the candidates to pick the fastest one.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
def test_tiled_conv3d():
    """Check that ``TiledConv3d`` matches the native ``torch.nn.Conv3d``.

    Every combination of padding, stride, kernel size and
    ``exclude_temporal_dim`` is run through a native ``Conv3d`` and through
    the tiled module built from it with ``TiledConv3d.from_native_conv3d``.
    Both outputs are compared with ``torch.testing.assert_close``.

    The input tensor and the layers are moved to the GPU, so a CUDA device
    is required.

    Raises:
        AssertionError: if at least one configuration yields mismatching
            outputs. All configurations are evaluated first, so the message
            lists every failing combination rather than only the first one.
    """
    data = torch.rand(1, 128, 51, 256, 256).cuda().to(torch.bfloat16)

    exclude_temporal_dim_options = [True, False]
    padding_options = [
        (0, 0, 0),
        (0, 1, 1),
        (1, 1, 1),
    ]
    stride_options = [(1, 1, 1), (2, 1, 1)]
    kernel_size_options = [(1, 1, 1), (3, 3, 3)]
    tile_size = 16

    # Mismatch reports are collected here so one bad configuration does not
    # hide the others, and so the test actually fails at the end.
    failures = []

    # itertools.product varies the rightmost iterable fastest, which is the
    # same order as nesting the loops padding > stride > kernel > exclude.
    configurations = itertools.product(
        padding_options,
        stride_options,
        kernel_size_options,
        exclude_temporal_dim_options,
    )
    for padding, stride, kernel_size, exclude_temporal_dim in configurations:
        conv3d = (
            torch.nn.Conv3d(128, 128, kernel_size=kernel_size, stride=stride, padding=padding)
            .cuda()
            .to(torch.bfloat16)
        )
        auto_tiled_conv3d = TiledConv3d.from_native_conv3d(
            conv3d, tile_size=tile_size, exclude_temporal_dim=exclude_temporal_dim
        )

        # compare
        with torch.inference_mode():
            out = conv3d(data)
            merged_out = auto_tiled_conv3d(data)

        # Peak CUDA memory so far; the counters are not reset between
        # configurations in this test.
        print(f"max allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024:.2f} GB")
        print(f"max reserved: {torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024:.2f} GB")

        try:
            torch.testing.assert_close(out, merged_out)
        except AssertionError as e:
            message = (
                f"Failed with padding={padding}, stride={stride}, kernel_size={kernel_size}, exclude_temporal_dim={exclude_temporal_dim}, error: {e}"
            )
            print(message)
            failures.append(message)

    if failures:
        raise AssertionError(
            f"{len(failures)} TiledConv3d configuration(s) did not match Conv3d:\n"
            + "\n".join(failures)
        )
if __name__ == "__main__":
    # Allow running the check directly as a script, without a test runner.
    test_tiled_conv3d()