sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/utils.py

29 lines
785 B
Python

from typing import List
import torch
def check_input(x: torch.Tensor) -> None:
    """Validate that ``x`` is a contiguous CUDA tensor.

    Args:
        x: tensor to validate.

    Raises:
        AssertionError: if ``x`` is not a CUDA tensor, or is not contiguous.
    """
    # Raise explicitly rather than using ``assert``: assert statements are
    # removed under ``python -O``, which would silently disable validation.
    # AssertionError is kept so callers that catch it are unaffected.
    if not x.is_cuda:
        raise AssertionError(f"{str(x)} must be a CUDA Tensor")
    if not x.is_contiguous():
        raise AssertionError(f"{str(x)} must be contiguous")
def check_dim(d, x: torch.Tensor) -> None:
    """Validate that ``x`` has exactly ``d`` dimensions.

    Args:
        d: expected number of dimensions, compared against ``x.dim()``.
        x: tensor to validate.

    Raises:
        AssertionError: if ``x.dim() != d``.
    """
    # Explicit raise so the check survives ``python -O`` (unlike ``assert``).
    if x.dim() != d:
        raise AssertionError(f"{str(x)} must be a {d}D tensor")
def check_shape(a: torch.Tensor, b: torch.Tensor) -> None:
    """Validate that ``a`` and ``b`` have identical shapes.

    Args:
        a: first tensor.
        b: second tensor.

    Raises:
        AssertionError: if the tensors differ in number of dimensions, or in
            the size of any dimension.
    """
    # Explicit raises so the checks survive ``python -O`` (unlike ``assert``).
    if a.dim() != b.dim():
        raise AssertionError("tensors should have same dim")
    # With equal rank, comparing the whole ``torch.Size`` is equivalent to
    # comparing every axis individually.
    if a.shape != b.shape:
        raise AssertionError(
            f"tensors shape mismatch, {a.size()} and {b.size()}"
        )
def check_device(tensors: List[torch.Tensor]) -> None:
    """Validate that every tensor in ``tensors`` is on the same device.

    The device of the first tensor is used as the reference. An empty list
    passes trivially, since there is nothing to compare.

    Args:
        tensors: tensors to compare.

    Raises:
        AssertionError: if any tensor's ``device`` differs from the first
            tensor's ``device``.
    """
    # Previously ``tensors[0]`` raised IndexError on an empty list.
    if not tensors:
        return
    device = tensors[0].device
    for t in tensors:
        # Explicit raise so the check survives ``python -O`` (unlike ``assert``).
        if t.device != device:
            raise AssertionError(
                f"All tensors should be on the same device, but got {device} and {t.device}"
            )