121 lines
4.3 KiB
C++
121 lines
4.3 KiB
C++
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
|
#include <ATen/core/Tensor.h>
|
|
#include <c10/util/irange.h>
|
|
#include <tuple>
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
#include <ATen/Functions.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#else
|
|
#include <ATen/ops/conv_tbc_backward_native.h>
|
|
#include <ATen/ops/conv_tbc_native.h>
|
|
#include <ATen/ops/empty.h>
|
|
#include <ATen/ops/zeros_like.h>
|
|
#endif
|
|
|
|
namespace at::native {
|
|
|
|
// 1-D convolution over a (time, batch, channel) laid-out input.
//
// Arguments:
//   self   - input, must be 3-D: (time, batch, in_channels)
//   weight - kernel, must be 3-D: (kernel_width, in_channels, out_channels)
//   bias   - must be 1-D with out_channels elements
//   pad    - padding applied on both ends of the time dimension
//
// Returns a freshly allocated tensor of shape
//   (time - kernel_width + 1 + 2 * pad, batch, out_channels)
// created with self.options().
//
// Raises (via TORCH_CHECK) when the ranks or the channel dimensions of the
// arguments are inconsistent.
//
// All index arithmetic is carried out in int64_t so that very long time
// dimensions are not truncated to 32 bits.
Tensor conv_tbc(const Tensor& self, const Tensor& weight, const Tensor& bias, int64_t pad) {
  TORCH_CHECK(self.dim() == 3, "Input must have 3 dims: time, batch, "
      "in_channel");
  TORCH_CHECK(weight.dim() == 3, "Weight tensor must have 3 dims: kernel_width,"
      " in_channels, out_channels.");
  TORCH_CHECK(bias.dim() == 1, "Bias must be 1-D");

  const auto input_size = self.sizes();
  const auto weight_size = weight.sizes();

  const int64_t ilen = input_size[0];
  const int64_t batchSize = input_size[1];
  const int64_t inputPlanes = input_size[2];
  const int64_t outputPlanes = weight_size[2];
  const int64_t kw = weight_size[0];
  const int64_t olen = ilen - kw + 1 + pad * 2;
  const int64_t real_pad = (olen - ilen + kw - 1) / 2;

  // Make sure shapes are correct.
  // Input = (time, batch, in_channels)
  // Weight = (kernel_width, in_channels, out_channels)
  // Bias = (out_channels)
  TORCH_CHECK(inputPlanes == weight_size[1], "Input dim 2 (input channels) "
      "is not == dim 1 in the weight tensor");
  TORCH_CHECK(outputPlanes == bias.sizes()[0], "Bias size must equal dim 2 in "
      "the weight tensor (output channels).");

  // input * weights + bias -> output_features
  // Start from the broadcast bias; every kernel tap then accumulates into it.
  Tensor output = at::empty({
    olen,
    batchSize,
    outputPlanes,
  }, self.options());
  output.copy_(bias.expand(output.sizes()));

  for (const auto k : c10::irange(kw)) {
    // First input / output time step touched by kernel tap k, and the number
    // of time steps for which both the input and the output index are valid.
    const int64_t iShift = std::max<int64_t>(0, k - real_pad);
    const int64_t oShift = std::max<int64_t>(0, real_pad - k);
    const int64_t t = std::min(ilen + real_pad - k, olen) - oShift;
    // Note: gemm assumes column-major matrices
    // input is l*m (row-major)
    // weight is m*r (row-major)
    // output is l*r (row-major)
    if (t > 0) {
      auto W = weight[k];
      auto I = self.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
      // O is a view into `output`, so the in-place addmm_ updates the result.
      auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      O.addmm_(I, W);
    }
  }
  return output;
}
|
|
|
|
std::tuple<Tensor, Tensor, Tensor> conv_tbc_backward(const Tensor& dOutput, const Tensor& input, const Tensor& weight, const Tensor& bias, int64_t pad) {
|
|
auto input_size = input.sizes();
|
|
auto weight_size = weight.sizes();
|
|
|
|
auto ilen = input_size[0];
|
|
auto batchSize = input_size[1];
|
|
auto inputPlanes = input_size[2];
|
|
auto outputPlanes = weight_size[2];
|
|
auto kw = weight.sizes()[0];
|
|
auto olen = input_size[0] - kw + 1 + pad * 2;
|
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
|
int real_pad = (olen - ilen + kw - 1) / 2;
|
|
|
|
Tensor dInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
|
for (int k = 0; k < kw; k++) {
|
|
int iShift = std::max(0, k - real_pad);
|
|
int oShift = std::max(0, real_pad - k);
|
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
|
int t = std::min(ilen + real_pad - k, olen) - oShift;
|
|
// dOutput * T(weight) -> dInput
|
|
if (t > 0) {
|
|
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
|
|
auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
|
|
dI.addmm_(dO, weight[k].t());
|
|
}
|
|
}
|
|
|
|
Tensor dWeight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
|
for (int k = 0; k < kw; k++) {
|
|
int iShift = std::max(0, k - real_pad);
|
|
int oShift = std::max(0, real_pad - k);
|
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
|
int t = std::min(ilen + real_pad - k, olen) - oShift;
|
|
// T(input) * dOutput -> dWeight
|
|
if (t > 0) {
|
|
auto dW = dWeight[k];
|
|
auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
|
|
auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
|
|
dW.addmm_(I, dO);
|
|
}
|
|
}
|
|
|
|
Tensor dBias = at::zeros_like(bias, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
|
auto tmp = dOutput.sum(0, false);
|
|
dBias.copy_(tmp.sum(0));
|
|
|
|
return std::make_tuple(dInput, dWeight, dBias);
|
|
}
|
|
|
|
} // namespace at::native
|