faiss_rag_enterprise/llama_index/embeddings/adapter_utils.py

"""Adapter utils."""
import json
import logging
import os
from abc import abstractmethod
from typing import Callable, Dict
import torch
import torch.nn.functional as F
from torch import Tensor, nn
logger = logging.getLogger(__name__)
class BaseAdapter(nn.Module):
"""Base adapter.
Can be subclassed to implement custom adapters.
    To implement a custom adapter, subclass this class and implement the
    following methods (a minimal example subclass is sketched after this class):
- get_config_dict
- forward
"""
@abstractmethod
def get_config_dict(self) -> Dict:
"""Get config dict."""
@abstractmethod
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass."""
def save(self, output_path: str) -> None:
"""Save model."""
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "config.json"), "w") as fOut:
json.dump(self.get_config_dict(), fOut)
torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
@classmethod
def load(cls, input_path: str) -> "BaseAdapter":
"""Load model."""
with open(os.path.join(input_path, "config.json")) as fIn:
config = json.load(fIn)
model = cls(**config)
model.load_state_dict(
torch.load(
os.path.join(input_path, "pytorch_model.bin"),
map_location=torch.device("cpu"),
)
)
return model
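

# A minimal sketch of a custom adapter, illustrating the subclassing contract
# above: implement get_config_dict (returning the constructor kwargs) and
# forward. The class name and its `scale` hyperparameter are hypothetical
# examples, not part of the library.
class _ExampleScaleAdapter(BaseAdapter):
    """Toy adapter that rescales embeddings by a learnable scalar (example only)."""

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        self._init_scale = float(scale)
        # a single learnable scaling factor
        self.scale = nn.Parameter(torch.tensor(self._init_scale))

    def get_config_dict(self) -> Dict:
        # constructor kwargs only, so load() can rebuild the module via cls(**config)
        return {"scale": self._init_scale}

    def forward(self, embed: Tensor) -> Tensor:
        return self.scale * embed


# Round-trip usage (sketch; the "/tmp/scale_adapter" path is illustrative):
#     adapter = _ExampleScaleAdapter(scale=1.0)
#     adapter.save("/tmp/scale_adapter")
#     restored = _ExampleScaleAdapter.load("/tmp/scale_adapter")

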
class LinearLayer(BaseAdapter):
"""Linear transformation.
Args:
in_features (int): Input dimension.
out_features (int): Output dimension.
bias (bool): Whether to use bias. Defaults to False.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.linear = nn.Linear(in_features, out_features, bias=bias)
        # seed with an identity matrix and a zero bias so the untrained adapter is
        # a no-op; this only yields a true identity map when
        # in_features == out_features
self.linear.weight.data.copy_(torch.eye(in_features, out_features))
if bias:
self.linear.bias.data.copy_(torch.zeros(out_features))
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass (Wv)."""
return self.linear(embed)
    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }
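

# Usage sketch for LinearLayer: with in_features == out_features, the identity
# initialization above makes the untrained adapter a no-op. The dimensions
# below are illustrative, not taken from any particular embedding model.
def _example_linear_adapter() -> None:
    adapter = LinearLayer(in_features=384, out_features=384)
    embed = torch.randn(2, 384)        # a batch of two example embeddings
    out = adapter(embed)               # Wv with W initialized to the identity
    assert torch.allclose(out, embed)  # identity init => output equals input

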
def get_activation_function(name: str) -> Callable:
"""Get activation function.
Args:
name (str): Name of activation function.
"""
activations: Dict[str, Callable] = {
"relu": F.relu,
"sigmoid": torch.sigmoid,
"tanh": torch.tanh,
"leaky_relu": F.leaky_relu,
# add more activations here as needed
}
if name not in activations:
raise ValueError(f"Unknown activation function: {name}")
return activations[name]
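

# Quick sketch of get_activation_function: look an activation up by name and
# apply it elementwise; unknown names raise ValueError.
def _example_activation_lookup() -> None:
    act = get_activation_function("leaky_relu")
    x = torch.tensor([-1.0, 0.0, 2.0])
    y = act(x)                   # F.leaky_relu with its default negative slope
    assert y[-1].item() == 2.0   # positive inputs pass through unchanged

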
class TwoLayerNN(BaseAdapter):
"""Two-layer transformation.
Args:
in_features (int): Input dimension.
hidden_features (int): Hidden dimension.
out_features (int): Output dimension.
bias (bool): Whether to use bias. Defaults to False.
activation_fn_str (str): Name of activation function. Defaults to "relu".
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
bias: bool = False,
activation_fn_str: str = "relu",
add_residual: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.hidden_features = hidden_features
self.out_features = out_features
self.bias = bias
self.activation_fn_str = activation_fn_str
        # NOTE: both linear layers always include a bias term; the `bias` flag is
        # only recorded in the config dict
        self.linear1 = nn.Linear(in_features, hidden_features, bias=True)
        self.linear2 = nn.Linear(hidden_features, out_features, bias=True)
        # both layers keep PyTorch's default weight initialization
self._activation_function = get_activation_function(activation_fn_str)
        self._add_residual = add_residual
        # residual_weight (initialized to 0) scales the two-layer output when the
        # residual connection is enabled, so the adapter starts as an identity map
        self.residual_weight = nn.Parameter(torch.zeros(1))
def forward(self, embed: Tensor) -> Tensor:
"""Forward pass (Wv).
Args:
embed (Tensor): Input tensor.
"""
output1 = self.linear1(embed)
output1 = self._activation_function(output1)
output2 = self.linear2(output1)
if self._add_residual:
output2 = self.residual_weight * output2 + embed
return output2
def get_config_dict(self) -> Dict:
"""Get config dict."""
return {
"in_features": self.in_features,
"hidden_features": self.hidden_features,
"out_features": self.out_features,
"bias": self.bias,
"activation_fn_str": self.activation_fn_str,
"add_residual": self._add_residual,
}
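

# Usage sketch for TwoLayerNN with the residual connection enabled. Because
# residual_weight is initialized to zero, the adapter starts out as an exact
# identity map and learns a correction on top of the input embedding. The
# dimensions are illustrative only.
def _example_two_layer_adapter() -> None:
    adapter = TwoLayerNN(
        in_features=384,
        hidden_features=1024,
        out_features=384,
        add_residual=True,
    )
    embed = torch.randn(2, 384)
    out = adapter(embed)
    # residual_weight == 0 at initialization, so the output equals the input
    assert torch.allclose(out, embed)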