import copy
import os
import pickle
import pytest
import test_models as TM
import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone

run_if_test_with_extended = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
)


@pytest.mark.parametrize(
"name, model_class",
[
("resnet50", models.ResNet),
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
("raft_large", models.optical_flow.RAFT),
("quantized_resnet50", models.quantization.QuantizableResNet),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
("mvit_v1_b", models.video.MViT),
],
)
def test_get_model(name, model_class):
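    """models.get_model(name) should build an instance of the expected model class."""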
    assert isinstance(models.get_model(name), model_class)


@pytest.mark.parametrize(
"name, model_fn",
[
("resnet50", models.resnet50),
("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
("raft_large", models.optical_flow.raft_large),
("quantized_resnet50", models.quantization.resnet50),
("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
("mvit_v1_b", models.video.mvit_v1_b),
],
)
def test_get_model_builder(name, model_fn):
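    """models.get_model_builder(name) should return the builder function itself."""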
    assert models.get_model_builder(name) == model_fn


@pytest.mark.parametrize(
"name, weight",
[
("resnet50", models.ResNet50_Weights),
("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
("raft_large", models.optical_flow.Raft_Large_Weights),
("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
("mvit_v1_b", models.video.MViT_V1_B_Weights),
],
)
def test_get_model_weights(name, weight):
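    """models.get_model_weights(name) should return the corresponding weights enum."""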
    assert models.get_model_weights(name) == weight


@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_copyable(copy_fn, name):
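    """copy.copy and copy.deepcopy should hand back the very same enum member."""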
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
        assert copy_fn(weights) is weights


@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
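    """A pickle round-trip should hand back the very same enum member."""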
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
        assert pickle.loads(pickle.dumps(weights)) is weights


def get_models_from_module(module):
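    """Collect the public lowercase builder names defined in a torchvision.models
    submodule, skipping the registration helpers re-exported from models._api."""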
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
    ]


@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
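    """list_models(module) should enumerate exactly the builders defined in the
    module, modulo the "quantized_" prefix used by the quantization namespace."""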
a = set(get_models_from_module(module))
b = {x.replace("quantized_", "") for x in models.list_models(module)}
assert len(b) > 0
    assert a == b


@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
{"*resnet*", "*alexnet*"},
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
{"resnet34", "*resnet1*"},
],
)
def test_list_models_filters(include_filters, exclude_filters):
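    """Recompute the expected include/exclude filtering with plain substring
    matching and compare it against what list_models returns."""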
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
    classification_models = set(get_models_from_module(models))

    if isinstance(include_filters, str):
        include_filters = [include_filters]
    if isinstance(exclude_filters, str):
        exclude_filters = [exclude_filters]

if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | {x for x in classification_models if include_f in x}
else:
expected = classification_models
if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = {x for x in classification_models if exclude_f in x}
expected = expected - a_exclude
    assert expected == actual


@pytest.mark.parametrize(
"name, weight",
[
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
(
"ResNet50_QuantizedWeights.DEFAULT",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
),
(
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
),
],
)
def test_get_weight(name, weight):
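    """get_weight should resolve a fully qualified weight name, including the
    DEFAULT alias, to the concrete enum member."""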
    assert models.get_weight(name) == weight


@pytest.mark.parametrize(
"model_fn",
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
def test_naming_conventions(model_fn):
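    """Every model builder must expose a weights enum, and any non-empty enum
    must define a DEFAULT alias."""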
weights_enum = get_model_weights(model_fn)
assert weights_enum is not None
    assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


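# Detection models are benchmarked at non-default input sizes; these
# (height, width) pairs are forwarded to get_ops in test_schema_meta_validation.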
detection_models_input_dims = {
"fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
"fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
"fasterrcnn_resnet50_fpn": (800, 800),
"fasterrcnn_resnet50_fpn_v2": (800, 800),
"fcos_resnet50_fpn": (800, 800),
"keypointrcnn_resnet50_fpn": (1333, 1333),
"maskrcnn_resnet50_fpn": (800, 800),
"maskrcnn_resnet50_fpn_v2": (800, 800),
"retinanet_resnet50_fpn": (800, 800),
"retinanet_resnet50_fpn_v2": (800, 800),
"ssd300_vgg16": (300, 300),
"ssdlite320_mobilenet_v3_large": (320, 320),
}


@pytest.mark.parametrize(
"model_fn",
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
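    """Validate the meta dict of every pre-trained weight: only permitted fields
    may appear, all task-mandatory fields must be present, and num_params, _ops
    and _file_size must match what is measured on the actual model."""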
if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
        pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")

# list of all possible supported high-level fields for weights meta-data
permitted_fields = {
"backend",
"categories",
"keypoint_names",
"license",
"_metrics",
"min_size",
"min_temporal_size",
"num_params",
"recipe",
"unquantized",
"_docs",
"_ops",
"_file_size",
}
# mandatory fields for each computer vision task
classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
defaults = {
"all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
"models": classification_fields,
"detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
"quantization": classification_fields | {"backend", "unquantized"},
"segmentation": {
"categories",
("_metrics", "COCO-val2017-VOC-labels", "miou"),
("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
},
"video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
"optical_flow": set(),
}
model_name = model_fn.__name__
module_name = model_fn.__module__.split(".")[-2]
    expected_fields = defaults["all"] | defaults[module_name]

    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_meta = []
    bad_names = []
for w in weights_enum:
actual_fields = set(w.meta.keys())
actual_fields |= {
("_metrics", dataset, metric_key)
for dataset in w.meta.get("_metrics", {}).keys()
for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
}
missing_fields = expected_fields - actual_fields
unsupported_fields = set(w.meta.keys()) - permitted_fields
if missing_fields or unsupported_fields:
problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
if module_name == "quantization":
# parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None:
if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_meta.append((w, "num_params"))
# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
# instead
if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
incorrect_meta.append((w, "_ops"))
else:
# loading the model and using it for parameter and ops verification
model = model_fn(weights=w)
if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
incorrect_meta.append((w, "num_params"))
kwargs = {}
if model_name in detection_models_input_dims:
                    # detection models have non-default height and width
height, width = detection_models_input_dims[model_name]
kwargs = {"height": height, "width": width}
if not model_fn.__name__.startswith("vit"):
# FIXME: https://github.com/pytorch/vision/issues/7871
calculated_ops = get_ops(model=model, weight=w, **kwargs)
if calculated_ops != w.meta["_ops"]:
incorrect_meta.append((w, "_ops"))
        if not w.name.isupper():
            bad_names.append(w)

        if get_file_size_mb(w) != w.meta.get("_file_size"):
            incorrect_meta.append((w, "_file_size"))

    assert not problematic_weights
    assert not incorrect_meta
    assert not bad_names


@pytest.mark.parametrize(
"model_fn",
TM.list_model_fns(models)
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
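    """The preprocessing transforms of every weight must be torch.jit scriptable
    for a task-appropriate dummy input."""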
model_name = model_fn.__name__
weights_enum = get_model_weights(model_fn)
if len(weights_enum) == 0:
pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
defaults = {
"models": {
"input_shape": (1, 3, 224, 224),
},
"detection": {
"input_shape": (3, 300, 300),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
},
"video": {
"input_shape": (1, 3, 4, 112, 112),
},
"optical_flow": {
"input_shape": (1, 3, 128, 128),
},
}
module_name = model_fn.__module__.split(".")[-2]
kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
x = torch.rand(input_shape)
if module_name == "optical_flow":
args = (x, x)
else:
if module_name == "video":
x = x.permute(0, 2, 1, 3, 4)
        args = (x,)

    problematic_weights = []
for w in weights_enum:
transforms = w.transforms()
try:
TM._check_jit_scriptable(transforms, args)
except Exception:
problematic_weights.append(w)
    assert not problematic_weights


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
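    """Tests for the decorator that maps the deprecated pretrained= argument
    onto the weights= API, e.g. pretrained=True -> weights=<default weights>."""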
class ModelWeights(WeightsEnum):
        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

@pytest.mark.parametrize(
"kwargs",
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
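        """Passing nothing, weights=None, or an actual enum member must not warn."""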
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
            pass

        builder(**kwargs)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
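        """Passing pretrained positionally should trigger the 'positional' warning."""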
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="positional"):
            builder(pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
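        """Passing pretrained as a keyword should trigger the deprecation warning."""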
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
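        """Positional and keyword pretrained must resolve to the same weights
        value reported in the warning message."""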
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
            pass

        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
        with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
            builder(*args, **kwargs)

def test_multi_params(self):
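        """Several weights/pretrained parameter pairs must be handled independently."""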
weights_params = ("weights", "weights_other")
        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.ModelWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
def builder(*, weights=None, weights_other=None):
            pass

        for pretrained_param in pretrained_params:
            with pytest.warns(UserWarning, match="deprecated"):
                builder(**{pretrained_param: True})

def test_default_callable(self):
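        """The default weights may be computed from other kwargs; pretrained=True
        must raise if that callable resolves to no weights."""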
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=True, flag=True)

        with pytest.raises(ValueError, match="weights"):
            builder(pretrained=True, flag=False)

@pytest.mark.parametrize(
"model_fn",
[fn for fn in TM.list_model_fns(models) if fn.__name__ not in {"vit_h_14", "regnet_y_128gf"}]
+ TM.list_model_fns(models.detection)
+ TM.list_model_fns(models.quantization)
+ TM.list_model_fns(models.segmentation)
+ TM.list_model_fns(models.video)
+ TM.list_model_fns(models.optical_flow)
+ [
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
],
)
@run_if_test_with_extended
def test_pretrained_deprecation(self, model_fn):
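        """Every real builder called with pretrained=True must emit the deprecation warning."""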
with pytest.warns(UserWarning, match="deprecated"):
model_fn(pretrained=True)