faiss_rag_enterprise/llama_index/embeddings/pooling.py

50 lines
1.4 KiB
Python

from enum import Enum
from typing import TYPE_CHECKING, Union, overload
import numpy as np
if TYPE_CHECKING:
import torch
class Pooling(str, Enum):
"""Enum of possible pooling choices with pooling behaviors."""
CLS = "cls"
MEAN = "mean"
def __call__(self, array: np.ndarray) -> np.ndarray:
if self == self.CLS:
return self.cls_pooling(array)
return self.mean_pooling(array)
@classmethod
@overload
def cls_pooling(cls, array: np.ndarray) -> np.ndarray:
...
@classmethod
@overload
# TODO: Remove this `type: ignore` after the false positive problem
# is addressed in mypy: https://github.com/python/mypy/issues/15683 .
def cls_pooling(cls, array: "torch.Tensor") -> "torch.Tensor": # type: ignore
...
@classmethod
def cls_pooling(
cls, array: "Union[np.ndarray, torch.Tensor]"
) -> "Union[np.ndarray, torch.Tensor]":
if len(array.shape) == 3:
return array[:, 0]
if len(array.shape) == 2:
return array[0]
raise NotImplementedError(f"Unhandled shape {array.shape}.")
@classmethod
def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
if len(array.shape) == 3:
return array.mean(axis=1)
if len(array.shape) == 2:
return array.mean(axis=0)
raise NotImplementedError(f"Unhandled shape {array.shape}.")