17 lines
395 B
Python
17 lines
395 B
Python
import torch.nn as nn
|
|
import torchvision.models as models
|
|
|
|
|
|
class EmbeddingNet(nn.Module):
    """Map inputs to unit-length embedding vectors.

    The module runs ``backbone`` on the input and L2-normalizes the result
    along dimension 1.

    Args:
        backbone: Module that produces the raw (unnormalized) embeddings.
            When ``None``, a torchvision ResNet-50 is built whose final
            classification layer outputs ``embedding_dim`` features; no
            pretrained weights are requested.
        embedding_dim: Output width of the default ResNet-50 backbone.
            Only used when ``backbone`` is ``None``; ignored otherwise.
            Defaults to 128.
    """

    def __init__(
        self,
        backbone: "nn.Module | None" = None,
        embedding_dim: int = 128,
    ):
        super().__init__()
        if backbone is None:
            # Reuse ResNet-50's final fully-connected layer as the embedding
            # head by sizing its "classes" to the desired embedding width.
            backbone = models.resnet50(num_classes=embedding_dim)
        self.backbone = backbone

    def forward(self, x):
        """Return the backbone's output for ``x``, L2-normalized over dim 1.

        ``x`` must be whatever the backbone accepts. For the default
        ResNet-50 that is an image batch, presumably ``(N, 3, H, W)``; confirm
        against callers. Normalizing over dim 1 assumes the backbone returns
        ``(N, D)`` features, so each row becomes a unit vector.
        """
        features = self.backbone(x)
        # Unit-norm rows make a dot product between embeddings equal to
        # their cosine similarity.
        return nn.functional.normalize(features, dim=1)