sglang_v0.5.2/vision_0.23.0/references/video_classification/transforms.py

10 lines
230 B
Python

import torch
import torch.nn as nn
class ConvertBCHWtoCBHW(nn.Module):
    """Convert tensor from (B, C, H, W) to (C, B, H, W).

    Swaps the first two of the last four axes. Any additional leading axes
    (for example a batch of clips shaped (N, B, C, H, W)) are left in place,
    giving (N, C, B, H, W).

    NOTE(review): in this video-classification reference the "B" axis is
    presumably the frame/time axis of a single clip; confirm against callers.
    """

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        """Return ``vid`` with axes -4 and -3 swapped.

        Args:
            vid: tensor with at least 4 dims whose last four axes are read as
                (B, C, H, W).

        Returns:
            A view of ``vid`` (no data copy) with shape (..., C, B, H, W).
            For 4-D input this is exactly ``vid.permute(1, 0, 2, 3)``.

        Raises:
            RuntimeError: if ``vid`` has fewer than 4 dims.
        """
        if vid.dim() < 4:
            # Same exception type the original permute(1, 0, 2, 3) raised,
            # but with a message that names the expected layout.
            raise RuntimeError(
                "ConvertBCHWtoCBHW expects a tensor with at least 4 dims "
                f"(..., B, C, H, W), got {vid.dim()} dims"
            )
        # Indexing from the end keeps 4-D behaviour identical to the original
        # permute while also accepting extra leading (batch) dims.
        return vid.transpose(-4, -3)