50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
from enum import Enum
|
|
from typing import TYPE_CHECKING, Union, overload
|
|
|
|
import numpy as np
|
|
|
|
if TYPE_CHECKING:
|
|
import torch
|
|
|
|
|
|
class Pooling(str, Enum):
    """Enum of possible pooling choices with pooling behaviors.

    Members are also ``str`` values (``"cls"`` and ``"mean"``), so a member can
    be looked up from its string value, e.g. ``Pooling("mean")``. Calling a
    member applies the corresponding pooling strategy to an array.
    """

    CLS = "cls"
    MEAN = "mean"

    def __call__(self, array: np.ndarray) -> np.ndarray:
        """Apply this member's pooling strategy to ``array``.

        Args:
            array: 2-D or 3-D array to pool.

        Returns:
            The result of :meth:`cls_pooling` for ``Pooling.CLS`` and of
            :meth:`mean_pooling` for every other member.

        Raises:
            NotImplementedError: If ``array`` is neither 2-D nor 3-D.
        """
        # Compare by identity against the class attribute rather than reaching
        # one member through another (``self.CLS``), a pattern the enum docs
        # discourage and whose support has varied across Python versions.
        if self is Pooling.CLS:
            return self.cls_pooling(array)
        return self.mean_pooling(array)

    @classmethod
    @overload
    def cls_pooling(cls, array: np.ndarray) -> np.ndarray:
        ...

    @classmethod
    @overload
    # TODO: Remove this `type: ignore` after the false positive problem
    # is addressed in mypy: https://github.com/python/mypy/issues/15683 .
    def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor":  # type: ignore
        ...

    @classmethod
    def cls_pooling(
        cls, array: "Union[np.ndarray, torch.Tensor]"
    ) -> "Union[np.ndarray, torch.Tensor]":
        """Select the first position of ``array``.

        For a 3-D input this returns ``array[:, 0]``, i.e. index 0 along axis 1
        for every entry of axis 0. For a 2-D input it returns ``array[0]``.
        NOTE(review): axis 1 (3-D) / axis 0 (2-D) is presumably the token axis
        with a [CLS]-style token first; confirm against callers.

        Args:
            array: 2-D or 3-D NumPy array or torch tensor.

        Returns:
            The selected slice, of the same array type as the input.

        Raises:
            NotImplementedError: If ``array`` is neither 2-D nor 3-D.
        """
        if len(array.shape) == 3:
            return array[:, 0]
        if len(array.shape) == 2:
            return array[0]
        raise NotImplementedError(f"Unhandled shape {array.shape}.")

    @classmethod
    def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
        """Average ``array`` over its position axis.

        For a 3-D input the mean is taken along axis 1; for a 2-D input it is
        taken along axis 0. Every position is averaged with equal weight and no
        mask is applied, so any padded positions contribute to the result;
        strip padding beforehand if that matters.

        Args:
            array: 2-D or 3-D array.

        Returns:
            The mean array with the reduced axis removed.

        Raises:
            NotImplementedError: If ``array`` is neither 2-D nor 3-D.
        """
        if len(array.shape) == 3:
            return array.mean(axis=1)
        if len(array.shape) == 2:
            return array.mean(axis=0)
        raise NotImplementedError(f"Unhandled shape {array.shape}.")
|