import torch
|
|
|
|
from . import benchmark
|
|
|
|
|
|
class RNNEltwise(benchmark.Benchmark):
    """Benchmark for the elementwise tail of an LSTM cell.

    The GEMM outputs of a real LSTM step are replaced with pre-generated
    random operands, so only the chunk/sigmoid/tanh/elementwise math is
    timed.  ``b`` is the batch size, ``hs`` the hidden size.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # Gate-stacked operands are [b, 4*hs]; the cell state cx is [b, hs].
        self.input = self._rand_operand(4 * hs)
        self.hx = self._rand_operand(4 * hs)
        self.cx = self._rand_operand(hs)
        self.b_ih = self._rand_operand(4 * hs)
        self.b_hh = self._rand_operand(4 * hs)
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    def _rand_operand(self, width):
        # Random [b, width] tensor on this benchmark's device/dtype; the
        # five operands built in __init__ differ only in their width.
        # (The base Benchmark stores device/dtype on self — see their use
        # in DynamicLSTM.instantiate_input.)
        return self.rand(
            [self.b, width],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )

    def forward(self, input, hx, cx, b_ih, b_hh):
        """One LSTM-cell elementwise step; returns (hy, cy)."""
        # All four gates arrive stacked along dim 1.
        gates = input + hx + b_ih + b_hh

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy

    def config(self):
        # [batch, hidden] — mirrors default_configs() entries.
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        """Bytes moved: all five inputs read, hy and cy written."""

        def memsize(t):
            return t.numel() * t.element_size()

        input_size = sum(memsize(t) for t in self.inputs)
        # hy and cy each have cx's shape [b, hs].
        output_size = 2 * memsize(self.cx)
        io_size = input_size + output_size
        return {"sol": io_size, "algorithmic": io_size}

    @staticmethod
    def default_configs():
        return [[64, 512]]
|
|
|
|
|
|
# Register so the benchmark harness can discover this workload by name.
benchmark.register_benchmark_class(RNNEltwise)
|
|
|
|
|
|
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise: operands are regenerated with
    randomized (b, hs) on each instantiate_input() call."""

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        """Re-randomize the shape and rebuild all five operands."""
        b, hs = self.rand_shape([self.b, self.hs])

        def make(width):
            # Random [b, width] tensor on this benchmark's device/dtype;
            # the operands below differ only in their width.
            return self.rand(
                [b, width],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        # Gate-stacked operands are [b, 4*hs]; the cell state cx is [b, hs].
        self.input = make(4 * hs)
        self.hx = make(4 * hs)
        self.cx = make(hs)
        self.b_ih = make(4 * hs)
        self.b_hh = make(4 * hs)
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    @staticmethod
    def module():
        return "dynamic_lstm"
|
|
|
|
|
|
# Register so the benchmark harness can discover this workload by name.
benchmark.register_benchmark_class(DynamicLSTM)
|