29 lines
785 B
Python
29 lines
785 B
Python
from typing import List
|
|
|
|
import torch
|
|
|
|
|
|
def check_input(x: torch.Tensor) -> None:
    """Validate that ``x`` is a contiguous CUDA tensor.

    The checks raise ``AssertionError`` explicitly instead of using ``assert``
    statements, so they still run under ``python -O``. The exception type and
    messages match the original assert-based version, so callers that catch
    ``AssertionError`` keep working.

    Args:
        x: Tensor to validate.

    Raises:
        AssertionError: If ``x`` is not on a CUDA device, or is not contiguous.
            The message embeds ``str(x)``, which can be long for big tensors.
    """
    if not x.is_cuda:
        raise AssertionError(f"{str(x)} must be a CUDA Tensor")
    if not x.is_contiguous():
        raise AssertionError(f"{str(x)} must be contiguous")
|
|
|
|
|
|
def check_dim(d: int, x: torch.Tensor) -> None:
    """Validate that ``x`` has exactly ``d`` dimensions.

    The check raises ``AssertionError`` explicitly instead of using an
    ``assert`` statement, so it still runs under ``python -O``.

    Args:
        d: Required number of dimensions (``x.dim()``).
        x: Tensor to validate.

    Raises:
        AssertionError: If ``x.dim() != d``.
    """
    if x.dim() != d:
        raise AssertionError(f"{str(x)} must be a {d}D tensor")
|
|
|
|
|
|
def check_shape(a: torch.Tensor, b: torch.Tensor) -> None:
    """Validate that ``a`` and ``b`` have identical shapes.

    The ranks are compared first so that a rank mismatch reports its own
    message. Once the ranks agree, comparing the two ``torch.Size`` values is
    equivalent to comparing every axis. The checks raise ``AssertionError``
    explicitly so they are not stripped under ``python -O``.

    Args:
        a: First tensor.
        b: Second tensor.

    Raises:
        AssertionError: If the ranks differ, or if any axis size differs.
    """
    if a.dim() != b.dim():
        raise AssertionError("tensors should have same dim")
    if a.size() != b.size():
        raise AssertionError(
            f"tensors shape mismatch, {a.size()} and {b.size()}"
        )
|
|
|
|
|
|
def check_device(tensors: List[torch.Tensor]) -> None:
    """Validate that every tensor in ``tensors`` is on the same device.

    The first tensor's device is the reference. An empty input has nothing to
    compare and passes. Any iterable of tensors is accepted (not only
    indexable sequences), because only the iterator protocol is used.
    Mismatches raise ``AssertionError`` explicitly so the check is not
    stripped under ``python -O``.

    Args:
        tensors: Tensors to compare.

    Raises:
        AssertionError: On the first tensor whose device differs from the
            first tensor's device.
    """
    remaining = iter(tensors)
    try:
        first = next(remaining)
    except StopIteration:
        # Nothing to compare; indexing [0] would have raised IndexError here.
        return
    device = first.device
    for t in remaining:
        if t.device != device:
            raise AssertionError(
                f"All tensors should be on the same device, but got {device} and {t.device}"
            )
|