#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Unfold3d.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#include <cstring> // for std::memcpy / std::memset

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif // AT_MKL_ENABLED()

namespace at::native {

namespace {

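// Checks 0 <= a < b with a single comparison: casting a negative `a` to
// uint64_t wraps it to a huge value, which then fails the `< b` test.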
bool IsAGeZeroAndALtB(int64_t a, int64_t b) {
  return static_cast<uint64_t>(a) < static_cast<uint64_t>(b);
}

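// Copies an M x N row-major sub-matrix from A (leading dimension lda) into B
// (leading dimension ldb). The first overload assumes unit column stride and
// copies whole rows with memcpy; the second handles arbitrary element strides.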
template <typename T>
void MatCopy(int64_t M, int64_t N, int64_t lda, int64_t ldb, const T* A, T* B) {
  for (const auto i : c10::irange(M)) {
    std::memcpy(B + i * ldb, A + i * lda, N * sizeof(T));
  }
}

template <typename T>
void MatCopy(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const T* A,
    T* B) {
  for (const auto i : c10::irange(M)) {
    const T* A_ptr = A + i * lda;
    T* B_ptr = B + i * ldb;
    for (const auto j : c10::irange(N)) {
      B_ptr[j * strideb] = A_ptr[j * stridea];
    }
  }
}

// Y += X
template <typename T>
void MatAdd(int64_t M, int64_t N, int64_t ldx, int64_t ldy, const T* X, T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j] += X[i * ldx + j];
    }
  }
}

// Y += X
template <typename T>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const T* X,
    T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j * stridey] += X[i * ldx + j * stridex];
    }
  }
}

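// When MKL is available, route the float/double instantiations to the MKL
// out-of-place copy/add routines (mkl_?omatcopy, mkl_?omatcopy2, mkl_?omatadd)
// and to cblas_?axpy for the strided accumulation.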
#if AT_MKL_ENABLED()

template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const float* A,
    float* B) {
  mkl_somatcopy('R', 'N', M, N, 1.0f, A, lda, B, ldb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t ldb,
    const double* A,
    double* B) {
  mkl_domatcopy('R', 'N', M, N, 1.0, A, lda, B, ldb);
}

template <>
void MatCopy<float>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const float* A,
    float* B) {
  mkl_somatcopy2('R', 'N', M, N, 1.0f, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatCopy<double>(
    int64_t M,
    int64_t N,
    int64_t lda,
    int64_t stridea,
    int64_t ldb,
    int64_t strideb,
    const double* A,
    double* B) {
  mkl_domatcopy2('R', 'N', M, N, 1.0, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatAdd<float>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const float* X,
    float* Y) {
  mkl_somatadd('R', 'N', 'N', M, N, 1.0f, X, ldx, 1.0f, Y, ldy, Y, ldy);
}

template <>
void MatAdd<double>(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t ldy,
    const double* X,
    double* Y) {
  mkl_domatadd('R', 'N', 'N', M, N, 1.0, X, ldx, 1.0, Y, ldy, Y, ldy);
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const float* X,
    float* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_saxpy(N, 1.0f, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

template <>
void MatAdd(
    int64_t M,
    int64_t N,
    int64_t ldx,
    int64_t stridex,
    int64_t ldy,
    int64_t stridey,
    const double* X,
    double* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_daxpy(N, 1.0, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

#endif // AT_MKL_ENABLED()

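// Unfold (vol2col) fast path for zero padding. The source is a contiguous
// C x X_D x X_H x X_W volume; the destination is laid out as a
// (C * kernel_d * kernel_h * kernel_w) x (Y_D * Y_H * Y_W) matrix, where row p
// holds, for one channel c and one kernel offset (kd, kh, kw), the input value
// seen at every output position. With no padding every source index is in
// bounds, so each output depth slab can be copied as a strided 2-D block via
// MatCopy. For example, C = 2 with a 3x3x3 kernel produces 2 * 27 = 54 rows.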
template <typename T>
void Unfold3dZeroPaddingCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d + kd;
        const T* src_ptr = src + c * X_size + xd * X_H * X_W + kh * X_W + kw;
        T* dst_ptr = dst + p * Y_size + yd * Y_H * Y_W;
        if (stride_w == 1) {
          MatCopy<T>(Y_H, Y_W, stride_h * X_W, Y_W, src_ptr, dst_ptr);
        } else {
          MatCopy<T>(
              Y_H, Y_W, stride_h * X_W, stride_w, Y_W, 1, src_ptr, dst_ptr);
        }
      }
    }
  });
}

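// General unfold (vol2col) with padding. Delegates to the zero-padding kernel
// when all pads are zero; otherwise out-of-bounds source positions are written
// as zeros, either a whole plane/row at once via memset or element by element.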
template <typename T>
void Unfold3dCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingCopyKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }

  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      const T* src_ptr = src + c * X_size;
      T* dst_ptr = dst + p * Y_size;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d - pad_d + kd;
        if (!IsAGeZeroAndALtB(xd, X_D)) {
          std::memset(dst_ptr + yd * Y_H * Y_W, 0, Y_H * Y_W * sizeof(T));
          continue;
        }
        for (const auto yh : c10::irange(Y_H)) {
          const int64_t xh = yh * stride_h - pad_h + kh;
          if (!IsAGeZeroAndALtB(xh, X_H)) {
            std::memset(
                dst_ptr + yd * Y_H * Y_W + yh * Y_W, 0, Y_W * sizeof(T));
            continue;
          }
          for (const auto yw : c10::irange(Y_W)) {
            const int64_t xw = yw * stride_w - pad_w + kw;
            dst_ptr[yd * Y_H * Y_W + yh * Y_W + yw] = IsAGeZeroAndALtB(xw, X_W)
                ? src_ptr[xd * X_H * X_W + xh * X_W + xw]
                : T(0);
          }
        }
      }
    }
  });
}

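// Fold accumulation (col2vol) fast path for zero padding: the transpose of the
// copy kernel above. dst is zeroed first, then every row of the unfolded src
// is added back into the C x X_D x X_H x X_W volume, so overlapping kernel
// windows sum their contributions. Parallelized over channels, so no two
// threads write the same output element.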
template <typename T>
void Unfold3dZeroPaddingAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d + kd;
              const T* src_ptr = src + p * Y_size + yd * Y_H * Y_W;
              T* dst_ptr = dst + c * X_size + xd * X_H * X_W + kh * X_W + kw;
              if (stride_w == 1) {
                MatAdd<T>(Y_H, Y_W, Y_W, stride_h * X_W, src_ptr, dst_ptr);
              } else {
                MatAdd<T>(
                    Y_H,
                    Y_W,
                    Y_W,
                    1,
                    stride_h * X_W,
                    stride_w,
                    src_ptr,
                    dst_ptr);
              }
            }
          }
        }
      }
    }
  });
}

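// General fold accumulation (col2vol) with padding. Delegates to the
// zero-padding kernel when all pads are zero; otherwise contributions whose
// source coordinate falls inside the padding region are simply skipped.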
template <typename T>
void Unfold3dAccKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingAccKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      T* dst_ptr = dst + c * X_size;
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p =
                c * kernel_size + kd * kernel_h * kernel_w + kh * kernel_w + kw;
            const T* src_ptr = src + p * Y_size;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d - pad_d + kd;
              if (!IsAGeZeroAndALtB(xd, X_D)) {
                continue;
              }
              for (const auto yh : c10::irange(Y_H)) {
                const int64_t xh = yh * stride_h - pad_h + kh;
                if (!IsAGeZeroAndALtB(xh, X_H)) {
                  continue;
                }
                for (const auto yw : c10::irange(Y_W)) {
                  const int64_t xw = yw * stride_w - pad_w + kw;
                  if (IsAGeZeroAndALtB(xw, X_W)) {
                    dst_ptr[xd * X_H * X_W + xh * X_W + xw] +=
                        src_ptr[yd * Y_H * Y_W + yh * Y_W + yw];
                  }
                }
              }
            }
          }
        }
      }
    }
  });
}

} // namespace

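// Public, type-erased entry points. Both dispatch over all standard dtypes
// plus BFloat16 and Half, then forward to the typed kernels above:
// Unfold3dCopyCPU unfolds src (C x X_D x X_H x X_W) into the column buffer
// dst, and Unfold3dAccCPU accumulates a column buffer back into a volume.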
void Unfold3dCopyCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dCopyCPU",
      [=, &src]() {
        Unfold3dCopyKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

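// A minimal usage sketch (illustrative only, not part of this file): a
// conv3d-as-GEMM forward pass for one example would unfold the input into a
// (C * kernel_d * kernel_h * kernel_w) x (Y_D * Y_H * Y_W) column buffer and
// then multiply that buffer by the weight matrix. Assuming float data and
// hypothetical caller-managed pointers `input_ptr` and `columns_ptr`:
//
//   Unfold3dCopyCPU(
//       at::kFloat,
//       input_ptr, // C x X_D x X_H x X_W, contiguous
//       C, X_D, X_H, X_W,
//       Y_D, Y_H, Y_W,
//       kernel_d, kernel_h, kernel_w,
//       stride_d, stride_h, stride_w,
//       pad_d, pad_h, pad_w,
//       columns_ptr); // (C * kernel_d * kernel_h * kernel_w) x (Y_D * Y_H * Y_W)
//
// Unfold3dAccCPU below is the adjoint of that operation: it scatters and sums
// a column buffer back into an input-shaped volume, which is typically what a
// gradient-with-respect-to-input computation needs.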
void Unfold3dAccCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dAccCPU",
      [=, &src]() {
        Unfold3dAccKernelImpl<scalar_t>(
            C,
            X_D,
            X_H,
            X_W,
            Y_D,
            Y_H,
            Y_W,
            kernel_d,
            kernel_h,
            kernel_w,
            stride_d,
            stride_h,
            stride_w,
            pad_d,
            pad_h,
            pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

} // namespace at::native