# sglang_v0.5.2/vision_0.23.0/test/common_utils.py
#
# 542 lines
# 17 KiB
# Python

import contextlib
import functools
import itertools
import os
import pathlib
import random
import re
import shutil
import sys
import tempfile
import warnings
from subprocess import CalledProcessError, check_output, STDOUT
import numpy as np
import PIL
import pytest
import torch
import torch.testing
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
from torchvision.utils import _Image_fromarray
# Environment flags, resolved once at import time from environment variables.
# True when either CI variable is set to the literal string "true".
IN_OSS_CI = any(os.getenv(name) == "true" for name in ("CIRCLECI", "GITHUB_ACTIONS"))
# True as soon as the variable exists, regardless of its value.
IN_RE_WORKER = "INSIDE_RE_WORKER" in os.environ
# True only when the variable is set to the literal string "1".
IN_FBCODE = os.getenv("IN_FBCODE_TORCHVISION") == "1"

# Human-readable messages; presumably used as skip reasons by the test configuration -- confirm against the callers.
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    """Yield the path of a fresh temporary directory, optionally pre-populated from ``src``.

    Args:
        src: Optional path of a directory whose contents are copied into the temporary directory.
        **kwargs: Forwarded to :class:`tempfile.TemporaryDirectory`.

    Yields:
        str: Path of the temporary directory. The directory is removed when the context exits.
    """
    with tempfile.TemporaryDirectory(**kwargs) as tmp_dir:
        if src is not None:
            # The temporary directory already exists at this point, so ``copytree`` has to be allowed to copy into an
            # existing destination. Without ``dirs_exist_ok=True`` it raises ``FileExistsError``.
            shutil.copytree(src, tmp_dir, dirs_exist_ok=True)
        yield tmp_dir
def set_rng_seed(seed):
    """Seed PyTorch's default generator and Python's ``random`` module with the same ``seed``."""
    for seed_fn in (torch.manual_seed, random.seed):
        seed_fn(seed)
class MapNestedTensorObjectImpl:
    """Callable that applies ``tensor_map_fn`` to every tensor found inside a nested structure.

    Dictionaries (keys and values), lists and tuples are traversed recursively and rebuilt. Tensors are replaced by
    the result of ``tensor_map_fn``. Every other object is returned unchanged.
    """

    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        if isinstance(object, dict):
            result = {}
            for key, value in object.items():
                # The value is mapped before the key.
                new_value = self(value)
                result[self(key)] = new_value
            return result

        if isinstance(object, tuple):
            return tuple(self(item) for item in object)

        if isinstance(object, list):
            return [self(item) for item in object]

        return object
def map_nested_tensor_object(object, tensor_map_fn):
    """Apply ``tensor_map_fn`` to every tensor nested inside ``object`` and return the rebuilt structure.

    This is a thin functional wrapper around :class:`MapNestedTensorObjectImpl`.
    """
    return MapNestedTensorObjectImpl(tensor_map_fn)(object)
def is_iterable(obj):
    """Return ``True`` if ``iter(obj)`` succeeds and ``False`` if it raises ``TypeError``."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True
@contextlib.contextmanager
def freeze_rng_state():
    """Snapshot the PyTorch RNG state on entry and restore it on exit.

    The CPU generator state is always saved. The state of the current CUDA device is saved as well if CUDA is
    available. The saved state is restored even if the body of the ``with`` block raises.
    """
    cpu_rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        # Restoring in ``finally`` keeps a failing body from leaking a modified RNG state into later code.
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(cpu_rng_state)
def cycle_over(objs):
    """Yield every ordered pair ``(a, b)`` of elements of ``objs`` taken from two different positions.

    ``objs`` has to support slicing and concatenation with ``+``, e.g. a list or a tuple.
    """
    for position, current in enumerate(objs):
        others = objs[:position] + objs[position + 1 :]
        yield from ((current, other) for other in others)
def int_dtypes():
    """Return the integer dtypes used by the tests, ordered from ``uint8`` to ``int64``."""
    return (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    )
def float_dtypes():
    """Return the floating point dtypes used by the tests: ``float32`` and ``float64``."""
    return (
        torch.float32,
        torch.float64,
    )
@contextlib.contextmanager
def disable_console_output():
    """Redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull`` for the duration of the context."""
    with contextlib.ExitStack() as redirects:
        with open(os.devnull, "w") as sink:
            for redirect in (contextlib.redirect_stdout, contextlib.redirect_stderr):
                redirects.enter_context(redirect(sink))
            yield
def cpu_and_cuda():
    """Return the device parametrization ``("cpu", "cuda")``, with the CUDA entry carrying the ``needs_cuda`` mark."""
    import pytest  # noqa

    cuda_param = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return "cpu", cuda_param
def cpu_and_cuda_and_mps():
    """Return the entries of :func:`cpu_and_cuda` followed by an ``"mps"`` entry carrying the ``needs_mps`` mark."""
    return (*cpu_and_cuda(), pytest.param("mps", marks=pytest.mark.needs_mps))
def needs_cuda(test_func):
    """Decorator that applies the ``needs_cuda`` pytest mark to ``test_func``."""
    import pytest  # noqa

    mark = pytest.mark.needs_cuda
    return mark(test_func)
def needs_mps(test_func):
    """Decorator that applies the ``needs_mps`` pytest mark to ``test_func``."""
    import pytest  # noqa

    mark = pytest.mark.needs_mps
    return mark(test_func)
def _create_data(height=3, width=3, channels=3, device="cpu"):
    """Create a random ``uint8`` image as a ``(channels, height, width)`` tensor plus the matching PIL image.

    The PIL image is built with mode ``"L"`` when ``channels == 1`` and with mode ``"RGB"`` otherwise.

    Returns:
        tuple: ``(tensor, pil_img)``. The tensor lives on ``device``; the PIL image is built from a CPU copy.
    """
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    # PIL expects channels-last data
    hwc_array = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    if channels == 1:
        pil_img = _Image_fromarray(hwc_array[..., 0], mode="L")
    else:
        pil_img = _Image_fromarray(hwc_array, mode="RGB")
    return tensor, pil_img
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
    """Create a random ``uint8`` tensor of shape ``(num_samples, channels, height, width)`` on ``device``."""
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    shape = (num_samples, channels, height, width)
    return torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    """Write ``num_videos`` random videos into ``tmpdir`` and return their paths.

    Args:
        tmpdir: Directory the files ``0.mp4``, ``1.mp4``, ... are written to.
        num_videos: Number of videos to create.
        sizes: Optional per-video number of frames. Defaults to ``5 * (i + 1)`` for the ``i``-th video.
        fps: Optional per-video frame rate. Defaults to ``5`` for every video.

    Returns:
        list: Paths of the written files, in creation order.
    """
    names = []
    for i in range(num_videos):
        num_frames = 5 * (i + 1) if sizes is None else sizes[i]
        frame_rate = 5 if fps is None else fps[i]
        data = torch.randint(0, 256, (num_frames, 300, 400, 3), dtype=torch.uint8)
        name = os.path.join(tmpdir, f"{i}.mp4")
        names.append(name)
        io.write_video(name, data, fps=frame_rate)
    return names
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    """Assert that ``tensor`` equals ``pil_image`` after converting the latter to a channels-first tensor."""
    # FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
    pil_array = np.array(pil_image)
    if pil_array.ndim == 2:
        # Give single-channel images an explicit trailing channel axis
        pil_array = pil_array[..., np.newaxis]
    pil_tensor = torch.as_tensor(pil_array.transpose((2, 0, 1)))

    if msg is None:
        msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
    assert_equal(tensor.cpu(), pil_tensor, msg=msg)
def _assert_approx_equal_tensor_to_pil(
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
    """Assert that ``tensor`` and ``pil_image`` differ by less than ``tol`` under an aggregated absolute error.

    Args:
        tensor: Channels-first image tensor.
        pil_image: Image that ``np.array`` can convert; it is moved to channels-first before comparing.
        tol: Exclusive upper bound for the aggregated absolute error.
        msg: Unused.
        agg_method: Name of the ``torch`` function that reduces the absolute error, e.g. ``"mean"`` or ``"max"``.
        allowed_percentage_diff: If given, upper bound for the fraction of elements that may differ.
    """
    # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
    pil_array = np.array(pil_image)
    if pil_array.ndim == 2:
        pil_array = pil_array[..., np.newaxis]
    pil_tensor = torch.as_tensor(pil_array.transpose((2, 0, 1))).to(tensor)

    if allowed_percentage_diff is not None:
        # Fraction of elements that are not exactly equal
        mismatch_ratio = (tensor != pil_tensor).to(torch.float).mean()
        assert mismatch_ratio <= allowed_percentage_diff

    # Compare in float so that the difference of unsigned integers cannot wrap around
    tensor, pil_tensor = tensor.to(torch.float), pil_tensor.to(torch.float)
    aggregate = getattr(torch, agg_method)
    err = aggregate(torch.abs(tensor - pil_tensor)).item()
    assert err < tol, f"{err} vs {tol}"
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
    """Check that ``fn`` gives the same result on a batch as on each sample, and optionally when scripted.

    Args:
        batch_tensors: Batched input; it is indexed along its first dimension.
        fn: Function under test. It is called as ``fn(input, **fn_kwargs)``.
        scripted_fn_atol: Absolute tolerance for comparing eager and scripted batch results. The scripted check is
            skipped unless this is ``>= 0``.
        **fn_kwargs: Extra keyword arguments for ``fn``.
    """
    batched_result = fn(batch_tensors, **fn_kwargs)
    for index in range(len(batch_tensors)):
        single_result = fn(batch_tensors[index, ...], **fn_kwargs)
        torch.testing.assert_close(single_result, batched_result[index, ...], rtol=0, atol=1e-6)

    if not scripted_fn_atol >= 0:
        return

    # scriptable function test
    scripted_fn = torch.jit.script(fn)
    scripted_result = scripted_fn(batch_tensors, **fn_kwargs)
    torch.testing.assert_close(batched_result, scripted_result, rtol=1e-5, atol=scripted_fn_atol)
def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.9) or :func:`functools.lru_cache` with infinite cache size,
    but this also caches exceptions.

    Positional and keyword arguments have to be hashable. The cache is unbounded. A call with positional arguments
    and the equivalent call with keyword arguments are cached independently.

    Args:
        fn: Callable to wrap.

    Returns:
        Wrapper that returns the cached output, or re-raises the cached exception, for arguments it has seen before.
    """
    sentinel = object()
    out_cache = {}
    exc_tb_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # The keyword *names* have to be part of the key. Otherwise `f(a=1, b=2)` and `f(b=1, a=2)` share one entry
        # and the second call gets the wrong result. Sorting makes the key independent of the keyword order at the
        # call site; names are unique, so the values themselves are never compared.
        key = (args, tuple(sorted(kwargs.items())))

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc_tb = exc_tb_cache.get(key, sentinel)
        if exc_tb is not sentinel:
            raise exc_tb[0].with_traceback(exc_tb[1])

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            # We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
            # traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
            # the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
            exc_tb_cache[key] = exc, exc.__traceback__
            raise exc

        out_cache[key] = out
        return out

    return wrapper
def combinations_grid(**kwargs):
    """Build the cartesian product of the given keyword options.

    Every keyword maps to an iterable of candidate values. The result is a list with one dictionary per
    combination, ordered like :func:`itertools.product` with the last keyword varying fastest.

    Example:
        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
        [
            {'foo': 'bar', 'spam': 'eggs'},
            {'foo': 'bar', 'spam': 'ham'},
            {'foo': 'baz', 'spam': 'eggs'},
            {'foo': 'baz', 'spam': 'ham'}
        ]
    """
    names = tuple(kwargs)
    grid = []
    for combination in itertools.product(*kwargs.values()):
        grid.append(dict(zip(names, combination)))
    return grid
class ImagePair(TensorLikePair):
    """Comparison pair for images that optionally compares by mean absolute error (MAE).

    If both inputs are PIL images, they are converted with ``to_image`` before being handed to
    :class:`TensorLikePair`.

    Args:
        actual: Actual input.
        expected: Expected input.
        mae: If ``True``, the values are compared through their MAE against ``atol`` instead of elementwise.
        **other_parameters: Forwarded to :class:`TensorLikePair`.
    """

    def __init__(self, actual, expected, *, mae=False, **other_parameters):
        both_pil = isinstance(actual, PIL.Image.Image) and isinstance(expected, PIL.Image.Image)
        if both_pil:
            actual, expected = to_image(actual), to_image(expected)

        super().__init__(actual, expected, **other_parameters)
        self.mae = mae

    def compare(self) -> None:
        actual, expected = self.actual, self.expected

        self._compare_attributes(actual, expected)
        actual, expected = self._equalize_attributes(actual, expected)

        if not self.mae:
            super()._compare_values(actual, expected)
            return

        if actual.dtype is torch.uint8:
            # Widen to a signed type so that the subtraction below cannot wrap around
            actual = actual.to(torch.int)
            expected = expected.to(torch.int)
        mae = float(torch.abs(actual - expected).float().mean())
        if mae > self.atol:
            self._fail(
                AssertionError,
                f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
            )
def assert_close(
    actual,
    expected,
    *,
    allow_subclasses=True,
    rtol=None,
    atol=None,
    equal_nan=False,
    check_device=True,
    check_dtype=True,
    check_layout=True,
    check_stride=False,
    msg=None,
    **kwargs,
):
    """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
    __tracebackhide__ = True

    # `ImagePair` is listed before `TensorLikePair`
    pair_types = (NonePair, BooleanPair, NumberPair, ImagePair, TensorLikePair)
    comparison_options = dict(
        allow_subclasses=allow_subclasses,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        check_stride=check_stride,
    )
    error_metas = not_close_error_metas(actual, expected, pair_types=pair_types, **comparison_options, **kwargs)

    if error_metas:
        # Only the first mismatch is reported
        raise error_metas[0].to_error(msg)
# Exact-match variant of `assert_close`: both tolerances are pinned to zero.
assert_equal = functools.partial(assert_close, atol=0, rtol=0)
# Default size for the `make_*` helpers. `make_keypoints` and `make_bounding_boxes` read it as (height, width).
DEFAULT_SIZE = (17, 11)

# Number of channels per supported color space name.
NUM_CHANNELS_MAP = dict(
    GRAY=1,
    GRAY_ALPHA=2,
    RGB=3,
    RGBA=4,
)
def make_image(
    size=DEFAULT_SIZE,
    *,
    color_space="RGB",
    batch_dims=(),
    dtype=None,
    device="cpu",
    memory_format=torch.contiguous_format,
):
    """Create a random ``tv_tensors.Image`` of shape ``(*batch_dims, num_channels, *size)``.

    Args:
        size: Spatial size of the image.
        color_space: Key of ``NUM_CHANNELS_MAP`` that selects the number of channels.
        batch_dims: Leading batch dimensions.
        dtype: Data type of the image. Defaults to ``torch.uint8``.
        device: Device the image is created on.
        memory_format: Memory format passed to :func:`torch.testing.make_tensor`.
    """
    num_channels = NUM_CHANNELS_MAP[color_space]
    if not dtype:
        dtype = torch.uint8
    max_value = get_max_value(dtype)

    shape = (*batch_dims, num_channels, *size)
    data = torch.testing.make_tensor(
        shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
    )

    has_alpha = color_space in ("GRAY_ALPHA", "RGBA")
    if has_alpha:
        # The alpha channel is the last one; set it to the maximum value everywhere
        data[..., -1, :, :] = max_value

    return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs):
    """Same as :func:`make_image`, but the result is returned as a plain ``torch.Tensor``."""
    image = make_image(*args, **kwargs)
    return image.as_subclass(torch.Tensor)
def make_image_pil(*args, **kwargs):
    """Same as :func:`make_image`, but the result is converted with ``to_pil_image``."""
    image = make_image(*args, **kwargs)
    return to_pil_image(image)
def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
    """Create ``num_points`` random ``tv_tensors.KeyPoints`` in ``(x, y)`` order that lie inside ``canvas_size``.

    ``canvas_size`` is read as ``(height, width)``: ``y`` is drawn from ``[0, canvas_size[0])`` and ``x`` from
    ``[0, canvas_size[1])``.
    """

    def sample_coordinate(upper):
        return torch.randint(0, upper, size=(num_points, 1), dtype=dtype, device=device)

    # y is sampled before x
    ys = sample_coordinate(canvas_size[0])
    xs = sample_coordinate(canvas_size[1])
    points = torch.cat((xs, ys), dim=-1)
    return tv_tensors.KeyPoints(points, canvas_size=canvas_size)
def make_bounding_boxes(
    canvas_size=DEFAULT_SIZE,
    *,
    format=tv_tensors.BoundingBoxFormat.XYXY,
    clamping_mode="soft",
    num_boxes=1,
    dtype=None,
    device="cpu",
):
    """Create ``num_boxes`` random ``tv_tensors.BoundingBoxes`` in the requested ``format``.

    Args:
        canvas_size: ``(height, width)`` of the canvas the boxes are sampled for.
        format: ``tv_tensors.BoundingBoxFormat`` member or its name as a string.
        clamping_mode: Forwarded to ``tv_tensors.BoundingBoxes``.
        num_boxes: Number of boxes to create.
        dtype: Data type of the boxes. Defaults to ``torch.float32``.
        device: Device the boxes are moved to.

    Raises:
        ValueError: If ``format`` is not one of the formats handled below.
    """

    def sample_position(values, max_value):
        # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and
        # high. Thus, one offset is drawn per box with an individual upper limit.
        offsets = []
        for extent in values.tolist():
            offsets.append(torch.randint(max_value - extent, ()))
        return torch.stack(offsets)

    box_format = tv_tensors.BoundingBoxFormat
    if isinstance(format, str):
        format = box_format[format]
    if not dtype:
        dtype = torch.float32

    # Sampling order (h, w, y, x, r) is kept fixed so that seeded runs stay reproducible
    h, w = [torch.randint(1, extent, (num_boxes,)) for extent in canvas_size]
    y = sample_position(h, canvas_size[0])
    x = sample_position(w, canvas_size[1])
    # Rotation angle in degrees
    r = -360 * torch.rand((num_boxes,)) + 180

    if format is box_format.XYWH:
        parts = (x, y, w, h)
    elif format is box_format.XYXY:
        parts = (x, y, x + w, y + h)
    elif format is box_format.CXCYWH:
        parts = (x + w / 2, y + h / 2, w, h)
    elif format is box_format.XYWHR:
        parts = (x, y, w, h, r)
    elif format is box_format.CXCYWHR:
        parts = (x + w / 2, y + h / 2, w, h, r)
    elif format is box_format.XYXYXYXY:
        r_rad = r * torch.pi / 180.0
        cos, sin = torch.cos(r_rad), torch.sin(r_rad)
        # Corner 2 follows the width edge from corner 1; corners 3 and 4 follow the height edge from 2 and 1
        x2, y2 = x + w * cos, y - w * sin
        x3, y3 = x2 + h * sin, y2 + h * cos
        x4, y4 = x + h * sin, y + h * cos
        parts = (x, y, x2, y2, x3, y3, x4, y4)
    else:
        raise ValueError(f"Format {format} is not supported")

    out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
    return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
    """Make "detection" masks of shape ``(num_masks, *size)``, i.e. one mask per object.

    ``dtype`` defaults to ``torch.bool``.
    """
    mask_data = torch.testing.make_tensor(
        (num_masks, *size), low=0, high=2, dtype=dtype or torch.bool, device=device
    )
    return tv_tensors.Mask(mask_data)
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
    """Make a "segmentation" mask of shape ``(*batch_dims, *size)``, where the category is the pixel value.

    ``dtype`` defaults to ``torch.uint8``. ``num_categories`` is passed as ``high`` to
    :func:`torch.testing.make_tensor`.
    """
    mask_data = torch.testing.make_tensor(
        (*batch_dims, *size), low=0, high=num_categories, dtype=dtype or torch.uint8, device=device
    )
    return tv_tensors.Mask(mask_data)
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
    """Create a random ``tv_tensors.Video``.

    The frames come from :func:`make_image`, with ``num_frames`` appended to ``batch_dims`` as the innermost batch
    dimension. Remaining keyword arguments are forwarded to :func:`make_image`.
    """
    frame_batch_dims = (*batch_dims, num_frames)
    frames = make_image(size, batch_dims=frame_batch_dims, **kwargs)
    return tv_tensors.Video(frames)
def make_video_tensor(*args, **kwargs):
    """Same as :func:`make_video`, but the result is returned as a plain ``torch.Tensor``."""
    video = make_video(*args, **kwargs)
    return video.as_subclass(torch.Tensor)
def assert_run_python_script(source_code):
    """Utility to check assertions in an independent Python subprocess.

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout. Modified from scikit-learn test utils.

    Args:
        source_code (str): The Python source code to execute.

    Raises:
        RuntimeError: If the script exits with a non-zero return code. The message contains the combined
            stdout / stderr of the script and the original ``CalledProcessError`` is attached as ``__cause__``.
        AssertionError: If the script succeeds, but prints anything to stdout or stderr.
    """
    with tempfile.TemporaryDirectory() as root:
        path = pathlib.Path(root) / "main.py"
        # Python reads source files as UTF-8 by default, so the script is written with that encoding regardless of
        # the locale of the machine.
        path.write_text(source_code, encoding="utf-8")

        try:
            out = check_output([sys.executable, str(path)], stderr=STDOUT)
        except CalledProcessError as e:
            # Undecodable bytes are replaced so that the real failure is not masked by a UnicodeDecodeError
            output = e.output.decode(errors="replace")
            raise RuntimeError(f"script errored with output:\n{output}") from e

        if out != b"":
            raise AssertionError(out.decode(errors="replace"))
@contextlib.contextmanager
def assert_no_warnings():
    """Turn every warning issued inside the context into an exception."""
    # Despite its name, `catch_warnings` does **not** catch anything. It only scopes the warning filters: whatever is
    # changed inside the context is reset upon exit.
    with warnings.catch_warnings():
        warnings.simplefilter(action="error")
        yield
@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
    """Silence the "does not have profile information" ``UserWarning`` inside the context."""
    # Calling a scripted object often triggers a warning like
    # `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
    # with varying `INT1` and `INT2`. They are uninteresting for us and only clutter the test summary, so every
    # `UserWarning` whose message starts with the fixed prefix is ignored.
    message_prefix = re.escape("operator() profile_node %")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=message_prefix, category=UserWarning)
        yield