"""Adapter utils."""

import json
import logging
import os
from abc import abstractmethod
from typing import Callable, Dict

import torch
import torch.nn.functional as F
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class BaseAdapter(nn.Module):
    """Base adapter.

    Can be subclassed to implement custom adapters.
    To implement a custom adapter, subclass this class and implement the
    following methods:

    - get_config_dict
    - forward

    """

    @abstractmethod
    def get_config_dict(self) -> Dict:
        """Get config dict."""

    @abstractmethod
    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass."""

    def save(self, output_path: str) -> None:
        """Save model."""
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut)
        torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @classmethod
    def load(cls, input_path: str) -> "BaseAdapter":
        """Load model."""
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        model = cls(**config)
        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model
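

# Illustrative sketch: a custom adapter only needs to implement
# ``get_config_dict`` and ``forward``; saving and loading then work through
# the BaseAdapter machinery above. The ``ScaleAdapter`` name and its ``scale``
# argument are hypothetical examples, not part of the adapters below.
class ScaleAdapter(BaseAdapter):
    """Toy example adapter that rescales embeddings by a learned scalar."""

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        self.init_scale = scale
        # learned scalar multiplier applied to every embedding
        self.scale = nn.Parameter(torch.tensor(float(scale)))

    def get_config_dict(self) -> Dict:
        # only constructor arguments belong here; weights go in the state dict
        return {"scale": self.init_scale}

    def forward(self, embed: Tensor) -> Tensor:
        return self.scale * embed

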
class LinearLayer(BaseAdapter):
    """Linear transformation.

    Args:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.

    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # seed with the identity matrix and a zero bias so the adapter starts
        # as a no-op; this only works for square weights
        # (in_features == out_features)
        self.linear.weight.data.copy_(torch.eye(in_features, out_features))
        if bias:
            self.linear.bias.data.copy_(torch.zeros(out_features))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv)."""
        return self.linear(embed)

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }
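

# Usage sketch (hypothetical dimensions and path): at initialization a square
# LinearLayer is an identity map, and any adapter round-trips through
# save()/load():
#
#     adapter = LinearLayer(in_features=384, out_features=384)
#     x = torch.randn(2, 384)
#     assert torch.allclose(adapter(x), x)  # identity at init
#     adapter.save("./linear_adapter")
#     restored = LinearLayer.load("./linear_adapter")

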
def get_activation_function(name: str) -> Callable:
    """Get activation function.

    Args:
        name (str): Name of activation function.

    """
    activations: Dict[str, Callable] = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
        "leaky_relu": F.leaky_relu,
        # add more activations here as needed
    }
    if name not in activations:
        raise ValueError(f"Unknown activation function: {name}")
    return activations[name]
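

# Usage sketch: look up an activation by name and apply it element-wise, e.g.
#
#     act = get_activation_function("relu")
#     act(torch.tensor([-1.0, 2.0]))  # tensor([0., 2.])

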
class TwoLayerNN(BaseAdapter):
    """Two-layer transformation.

    Args:
        in_features (int): Input dimension.
        hidden_features (int): Hidden dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.
        activation_fn_str (str): Name of activation function. Defaults to "relu".
        add_residual (bool): Whether to add a residual connection scaled by a
            learned weight. Defaults to False.

    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        bias: bool = False,
        activation_fn_str: str = "relu",
        add_residual: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.bias = bias
        self.activation_fn_str = activation_fn_str

        # NOTE: the `bias` flag is recorded in the config, but both linear
        # layers below are constructed with bias=True regardless of it.
        self.linear1 = nn.Linear(in_features, hidden_features, bias=True)
        self.linear2 = nn.Linear(hidden_features, out_features, bias=True)
        # self.linear1.weight.data.copy_(torch.zeros(hidden_features, in_features))
        # self.linear2.weight.data.copy_(torch.zeros(out_features, hidden_features))
        # if bias:
        #     self.linear1.bias.data.copy_(torch.zeros(hidden_features))
        #     self.linear2.bias.data.copy_(torch.zeros(out_features))

        self._activation_function = get_activation_function(activation_fn_str)
        self._add_residual = add_residual
        # residual_weight (initialized to 0) is only used when add_residual is True
        self.residual_weight = nn.Parameter(torch.zeros(1))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv).

        Args:
            embed (Tensor): Input tensor.

        """
        output1 = self.linear1(embed)
        output1 = self._activation_function(output1)
        output2 = self.linear2(output1)

        if self._add_residual:
            # the residual connection requires out_features == in_features
            output2 = self.residual_weight * output2 + embed

        return output2

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "activation_fn_str": self.activation_fn_str,
            "add_residual": self._add_residual,
        }
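

if __name__ == "__main__":
    # Smoke-test sketch (hypothetical dimensions): with add_residual=True the
    # residual weight starts at zero, so a freshly initialized TwoLayerNN acts
    # as an identity map on its input embeddings.
    demo = TwoLayerNN(
        in_features=384,
        hidden_features=1024,
        out_features=384,
        add_residual=True,
    )
    batch = torch.randn(4, 384)
    assert torch.allclose(demo(batch), batch)
    logger.info("TwoLayerNN config: %s", demo.get_config_dict())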