sglang_v0.5.2/vision_0.22.1/examples/cpp/script_model.py

11 lines
326 B
Python

import torch
from torchvision import models
for model, name in (
(models.resnet18(weights=None), "resnet18"),
(models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None), "fasterrcnn_resnet50_fpn"),
):
model.eval()
traced_model = torch.jit.script(model)
traced_model.save(f"{name}.pt")