# sglang_v0.5.2/vision_0.23.0/references/similarity/model.py
# (17 lines, 395 B, Python)

import torch.nn as nn
import torchvision.models as models
class EmbeddingNet(nn.Module):
    """Wrap a backbone network and L2-normalize the embeddings it produces.

    Args:
        backbone: Module that maps an input batch to embedding vectors. When
            ``None``, a ResNet-50 is built via ``torchvision.models.resnet50``
            with ``num_classes=embedding_dim``. No pretrained weights are
            requested, so it starts randomly initialized.
        embedding_dim: Output size of the default ResNet-50 head. It is only
            used when ``backbone`` is ``None``. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, backbone=None, *, embedding_dim: int = 128):
        super().__init__()
        if backbone is None:
            # The classifier head's ``num_classes`` outputs become the
            # embedding dimensions, which are normalized in ``forward``.
            backbone = models.resnet50(num_classes=embedding_dim)
        self.backbone = backbone

    def forward(self, x):
        """Embed ``x`` with the backbone and scale each embedding to unit norm.

        Args:
            x: Input batch accepted by ``self.backbone``.

        Returns:
            The backbone output passed through
            ``nn.functional.normalize(..., dim=1)`` (p=2 by default). Each
            slice along dim 1 therefore has unit Euclidean norm; all-zero rows
            stay zero because of the function's ``eps`` guard. This assumes
            the backbone returns (batch, features) — TODO confirm for custom
            backbones.
        """
        embeddings = self.backbone(x)
        return nn.functional.normalize(embeddings, dim=1)