// Source: sglang_v0.5.2/pytorch_2.8.0/aten/src/ATen/native/BlasKernel.cpp
// (listing metadata: 791 lines, 25 KiB, C++)

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Config.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Unroll.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <climits>
#include <limits>
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#include <cpuinfo.h>
#endif
namespace {
/// Strips the `const` qualifier from a pointer while deducing the pointee
/// type, so call sites do not have to spell out `const_cast<T*>`.
///
/// Intended for passing arguments to APIs that are not const-correct.
template <typename T>
T* remove_const(const T* x) {
  T* const writable = const_cast<T*>(x);
  return writable;
}
} // namespace
#if AT_BUILD_WITH_BLAS()
#ifndef _ARMPL_H
// BLAS routines declared with C linkage. Every argument, scalars included,
// is passed by pointer. The declarations are skipped when _ARMPL_H is
// defined (presumably because the Arm Performance Libraries header already
// declares them - confirm).
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
extern "C" void sscal_(int *n, float *a, float *x, int *incx);
extern "C" void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
#endif
// `ffloat` is the return type used for `sdot_` in this file: `double` when
// AT_BLAS_F2C() is set, `float` otherwise.
#if AT_BLAS_F2C()
# define ffloat double
#else
# define ffloat float
#endif
#if AT_BLAS_USE_CBLAS_DOT()
// CBLAS dot-product entry points (scalars passed by value). The complex
// `*_sub` variants return nothing and take an extra trailing output pointer.
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
extern "C" void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_cdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
extern "C" void cblas_zdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
#ifndef _ARMPL_H
/// Adapter that exposes `cblas_sdot` through the pointer-argument `sdot_`
/// signature. The float result is widened to `ffloat` on return.
static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
{
  const float dot = cblas_sdot(*n, x, *incx, y, *incy);
  return dot;
}
#endif
// Adapters that expose the CBLAS `*_sub` complex dot products through the
// pointer-argument signatures (`cdotu_`, `zdotu_`, `cdotc_`, `zdotc_`) used
// by the rest of this file. The result is written through `res`.

/// Single-precision complex dot product via cblas_cdotu_sub.
static inline void cdotu_(
    std::complex<float> *res,
    const int *n,
    const std::complex<float> *x,
    const int *incx,
    const std::complex<float> *y,
    const int *incy) {
  cblas_cdotu_sub(*n, x, *incx, y, *incy, res);
}

/// Double-precision complex dot product via cblas_zdotu_sub.
static inline void zdotu_(
    std::complex<double> *res,
    const int *n,
    const std::complex<double> *x,
    const int *incx,
    const std::complex<double> *y,
    const int *incy) {
  cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
}

/// Single-precision complex dot product via cblas_cdotc_sub.
static inline void cdotc_(
    std::complex<float> *res,
    const int *n,
    const std::complex<float> *x,
    const int *incx,
    const std::complex<float> *y,
    const int *incy) {
  cblas_cdotc_sub(*n, x, *incx, y, *incy, res);
}

/// Double-precision complex dot product via cblas_zdotc_sub.
static inline void zdotc_(
    std::complex<double> *res,
    const int *n,
    const std::complex<double> *x,
    const int *incx,
    const std::complex<double> *y,
    const int *incy) {
  cblas_zdotc_sub(*n, x, *incx, y, *incy, res);
}
#else
// Used when AT_BLAS_USE_CBLAS_DOT() is not set: the dot routines are declared
// directly with pointer arguments. The complex variants take the output
// pointer `res` as their first argument.
extern "C" ffloat sdot_(int *n, const float *x, int *incx, const float *y, int *incy);
extern "C" void cdotu_(std::complex<float> *res, int *n, const std::complex<float> *x, int *incx, const std::complex<float> *y, int *incy);
extern "C" void zdotu_(std::complex<double> *res, int *n, const std::complex<double> *x, int *incx, const std::complex<double> *y, int *incy);
extern "C" void cdotc_(std::complex<float> *res, int *n, const std::complex<float> *x, int *incx, const std::complex<float> *y, int *incy);
extern "C" void zdotc_(std::complex<double> *res, int *n, const std::complex<double> *x, int *incx, const std::complex<double> *y, int *incy);
#endif // AT_BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS
namespace at::native {
#if !defined(C10_MOBILE)
// Dispatch stubs for the reduced-precision (fp16 / bf16) gemv and dot
// kernels. The blas_impl wrappers in this file invoke them with kCPU.
// Only built for non-mobile targets.
DEFINE_DISPATCH(fp16_gemv_trans_stub);
DEFINE_DISPATCH(bf16_gemv_trans_stub);
DEFINE_DISPATCH(fp16_dot_stub);
DEFINE_DISPATCH(bf16_dot_stub);
#endif // !defined(C10_MOBILE)
namespace blas_impl {
#if !defined(C10_MOBILE)
// Prototype precedes the definition because the function has external linkage.
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy);

/// Transposed fp16 gemv: hands every argument, unchanged, to
/// `fp16_gemv_trans_stub` on the CPU device.
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy) {
  fp16_gemv_trans_stub(
      kCPU,
      m,
      n,
      alpha,
      a,
      lda,
      x,
      incx,
      beta,
      y,
      incy);
}
/// fp16 dot product: forwards to `fp16_dot_stub` on the CPU device and
/// returns its result as a float.
static float fp16_dot(
    const int64_t n,
    const Half* x,
    const int64_t incx,
    const Half* y,
    const int64_t incy) {
  const float dot = fp16_dot_stub(kCPU, n, x, incx, y, incy);
  return dot;
}
/// bf16 dot product: forwards to `bf16_dot_stub` on the CPU device and
/// returns its result as a float.
static float bf16_dot(
    const int64_t n,
    const BFloat16* x,
    const int64_t incx,
    const BFloat16* y,
    const int64_t incy) {
  const float dot = bf16_dot_stub(kCPU, n, x, incx, y, incy);
  return dot;
}
#endif // !defined(C10_MOBILE)
#if defined(__aarch64__) && !defined(C10_MOBILE)
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
/// Computes y = A * x with all accumulation done in fp16.
///
/// Column `j` of A starts at `a + lda * j`; x and y are read/written with
/// unit stride. Rows are processed four at a time with no tail handling, so
/// `m` must be a multiple of 4 (the caller in this file checks `m % 4 == 0`).
/// The first column overwrites y; later columns accumulate into it.
static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
  for (int col = 0; col < n; ++col) {
    // Broadcast x[col] so it can be used as the FMA lane operand.
    const auto x_bcast = vdup_n_f16(x[col]);
    const float16_t* col_ptr = a + lda * col;
    for (int row = 0; row < m; row += 4) {
      float16_t* out = y + row;
      const auto a_vec = vld1_f16(col_ptr + row);
      auto acc = (col == 0) ? vdup_n_f16(0) : vld1_f16(out);
      acc = vfma_lane_f16(acc, a_vec, x_bcast, 0);
      vst1_f16(out, acc);
    }
  }
}
#endif
static void fp16_gemv_notrans_fp32_arith(int m, int n, const float16_t* a, const int lda, const float16_t *x, float16_t *y) {
std::vector<float> sum(m);
for (auto j = 0; j < n; j++) {
auto vecCol = vdup_n_f32(x[j]);
const auto* column = a + lda * j;
for (auto i = 0; i < m; i += 4) {
auto sf32 = sum.data() + i;
auto matRow = vcvt_f32_f16(vld1_f16(column + i));
auto resVec = j != 0 ? vld1q_f32(sf32) : vdupq_n_f32(0);
resVec = vfmaq_lane_f32(resVec, matRow, vecCol, 0);
vst1q_f32(sf32, resVec);
}
}
for (auto i = 0; i < m; i+= 4) {
vst1_f16(y + i, vcvt_f16_f32(vld1q_f32(sum.data() + i)));
}
}
// Prototype precedes the definition because the function has external linkage.
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy);

/// Non-transposed fp16 gemv: y = alpha * A * x + beta * y.
///
/// Column `j` of A starts at `a + lda * j`.
///
/// @param m     number of rows of A / elements of y
/// @param n     number of columns of A / elements of x
/// @param alpha scale applied to A * x
/// @param a     matrix data
/// @param lda   distance (in elements) between consecutive columns of A
/// @param x     input vector, stride `incx`
/// @param beta  scale applied to the existing contents of y; when it is
///              exactly 0, y is overwritten without being read
/// @param y     output vector, stride `incy`
///
/// When incx == 1, incy == 1, alpha == 1, beta == 0 and m is a multiple of 4,
/// the NEON kernels are used (fp16 accumulation if the global context allows
/// it and the target supports it, fp32 accumulation otherwise). Every other
/// configuration goes through the scalar fallback, which accumulates in fp32.
void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy) {
  if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && incy == 1) {
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
    if (at::globalContext().allowFP16ReductionCPU()) {
      return fp16_gemv_notrans_fp16_arith(m, n, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(x), reinterpret_cast<float16_t*>(y));
    }
#endif
    return fp16_gemv_notrans_fp32_arith(m, n, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(x), reinterpret_cast<float16_t*>(y));
  }

  // Scalar fallback: accumulate alpha * A * x in fp32, one column at a time.
  std::vector<float> sum(m);
  for (const auto j : c10::irange(n)) {
    const auto* column_ = a + lda * j;
    auto z = alpha * x[j * incx];
    for (const auto i : c10::irange(m)) {
      sum[i] += z * column_[i];
    }
  }

  if (beta == 0.0) {
    // beta == 0: do not read y, so NaN/Inf already in y are not propagated.
    for (const auto i : c10::irange(m)) {
      y[i * incy] = sum[i];
    }
  } else {
    // Apply beta to the existing output. Previously this branch computed
    // `y += sum`, silently dropping beta whenever it was neither 0 nor 1.
    for (const auto i : c10::irange(m)) {
      y[i * incy] = beta * y[i * incy] + sum[i];
    }
  }
}
#endif // defined(__aarch64__) && !defined(C10_MOBILE)
/// Primary template: reports whether `scal` may take a BLAS fast path for
/// `scalar_t`. The generic answer is always "no"; specializations elsewhere
/// in the file opt specific types in.
template <typename scalar_t>
static bool scal_use_fast_path(
    [[maybe_unused]] int64_t n,
    [[maybe_unused]] int64_t incx) {
  constexpr bool kHasFastPath = false;
  return kHasFastPath;
}
/// Primary template: reports whether `gemv` may take a fast path for
/// `scalar_t` with the given problem description. The generic answer is
/// always "no"; specializations elsewhere in the file opt specific types in.
template <typename scalar_t>
static bool gemv_use_fast_path(
    [[maybe_unused]] char trans,
    [[maybe_unused]] int64_t m,
    [[maybe_unused]] int64_t n,
    [[maybe_unused]] scalar_t alpha,
    [[maybe_unused]] int64_t lda,
    [[maybe_unused]] int64_t incx,
    [[maybe_unused]] scalar_t beta,
    [[maybe_unused]] int64_t incy) {
  constexpr bool kHasFastPath = false;
  return kHasFastPath;
}
/// Primary template for the `scal` fast path. It must never run: the matching
/// generic `scal_use_fast_path` always returns false, so reaching this body
/// trips an internal assertion.
template <typename scalar_t>
static void scal_fast_path(
    [[maybe_unused]] int* n,
    [[maybe_unused]] scalar_t* a,
    [[maybe_unused]] scalar_t* x,
    [[maybe_unused]] int* incx) {
  TORCH_INTERNAL_ASSERT(
      false,
      "scal_fast_path shouldn't be called for this configuration");
}
/// Primary template for the `gemv` fast path. It must never run: the matching
/// generic `gemv_use_fast_path` always returns false, so reaching this body
/// trips an internal assertion.
template <typename scalar_t>
static void gemv_fast_path(
    [[maybe_unused]] const char* trans,
    [[maybe_unused]] const int* m,
    [[maybe_unused]] const int* n,
    [[maybe_unused]] const scalar_t* alpha,
    [[maybe_unused]] const scalar_t* a,
    [[maybe_unused]] const int* lda,
    [[maybe_unused]] const scalar_t* x,
    [[maybe_unused]] const int* incx,
    [[maybe_unused]] const scalar_t* beta,
    [[maybe_unused]] scalar_t* y,
    [[maybe_unused]] const int* incy) {
  TORCH_INTERNAL_ASSERT(
      false,
      "gemv_fast_path shouldn't be called for this configuration");
}
// Explicitly instantiates the four primary fast-path templates
// (scal_use_fast_path, gemv_use_fast_path, gemv_fast_path, scal_fast_path)
// for `scalar_t`. Used for types that have no specialization.
#define INSTANTIATE(scalar_t) \
template bool scal_use_fast_path<scalar_t>(int64_t n, int64_t incx); \
template bool gemv_use_fast_path<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \
template void gemv_fast_path<scalar_t>(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \
template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *incx);
#if AT_BUILD_WITH_BLAS()
/// double: BLAS `dscal_` takes `int` arguments, so the fast path is usable
/// only when both the length and the stride are no larger than INT_MAX.
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const bool n_fits = n <= kIntMax;
  const bool incx_fits = incx <= kIntMax;
  return n_fits && incx_fits;
}

/// float: same eligibility rule as double.
template <>
bool scal_use_fast_path<float>(int64_t n, int64_t incx) {
  const bool eligible = scal_use_fast_path<double>(n, incx);
  return eligible;
}
/// double: scales x in place through BLAS `dscal_`.
template <>
void scal_fast_path<double>(int *n, double *a, double *x, int *incx) {
  dscal_(n, a, x, incx);
}

/// float: scales x in place through BLAS `sscal_`.
template <>
void scal_fast_path<float>(int *n, float *a, float *x, int *incx) {
  sscal_(n, a, x, incx);
}
/// float: the BLAS gemv fast path is usable when the dimensions and the
/// leading dimension are no larger than INT_MAX and both strides lie in
/// [1, INT_MAX]. `trans`, `alpha` and `beta` play no part in the decision.
template <>
bool gemv_use_fast_path<float>(
    [[maybe_unused]] char trans,
    int64_t m,
    int64_t n,
    [[maybe_unused]] float alpha,
    int64_t lda,
    int64_t incx,
    [[maybe_unused]] float beta,
    int64_t incy) {
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  if (m > kIntMax || n > kIntMax || lda > kIntMax) {
    return false;
  }
  const bool incx_ok = incx > 0 && incx <= kIntMax;
  const bool incy_ok = incy > 0 && incy <= kIntMax;
  return incx_ok && incy_ok;
}
/// double: same eligibility rule as float. The float predicate ignores
/// alpha and beta, so narrowing them for the call has no effect on the result.
template <>
bool gemv_use_fast_path<double>(
    [[maybe_unused]] char trans,
    int64_t m,
    int64_t n,
    [[maybe_unused]] double alpha,
    int64_t lda,
    int64_t incx,
    [[maybe_unused]] double beta,
    int64_t incy) {
  const float alpha_f = static_cast<float>(alpha);
  const float beta_f = static_cast<float>(beta);
  return gemv_use_fast_path<float>(
      trans, m, n, alpha_f, lda, incx, beta_f, incy);
}
/// double: forwards to BLAS `dgemv_`. That declaration is not const-correct,
/// so the const-qualified inputs are passed through `remove_const`.
template <>
void gemv_fast_path<double>(const char *trans, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy) {
  dgemv_(
      remove_const(trans),
      remove_const(m),
      remove_const(n),
      remove_const(alpha),
      remove_const(a),
      remove_const(lda),
      remove_const(x),
      remove_const(incx),
      remove_const(beta),
      y,
      remove_const(incy));
}

/// float: forwards to BLAS `sgemv_`, with the same const handling.
template <>
void gemv_fast_path<float>(const char *trans, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy) {
  sgemv_(
      remove_const(trans),
      remove_const(m),
      remove_const(n),
      remove_const(alpha),
      remove_const(a),
      remove_const(lda),
      remove_const(x),
      remove_const(incx),
      remove_const(beta),
      y,
      remove_const(incy));
}
// Integer types have no fast-path specializations, so explicitly instantiate
// the generic templates for them.
INSTANTIATE(uint8_t)
INSTANTIATE(int8_t)
INSTANTIATE(int16_t)
INSTANTIATE(int)
INSTANTIATE(int64_t)
#if !defined(C10_MOBILE)
/// BFloat16: the fast path covers only the transposed case with unit `incx`,
/// alpha == 1 and beta == 0.
///
/// `gemv` narrows m, n, lda and incy to `int` before calling
/// `gemv_fast_path`, so values above INT_MAX must be rejected here.
/// Otherwise they would be truncated silently; rejecting them sends such
/// calls to the generic 64-bit loop instead.
template <>
bool gemv_use_fast_path<at::BFloat16>(
    char trans,
    int64_t m,
    int64_t n,
    at::BFloat16 alpha,
    int64_t lda,
    int64_t incx,
    at::BFloat16 beta,
    int64_t incy) {
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const bool fits_in_int =
      m <= kIntMax && n <= kIntMax && lda <= kIntMax && incy <= kIntMax;
  return fits_in_int && (trans == 'T' || trans == 't') && incx == 1 &&
      alpha == 1.0 && beta == 0.0;
}
/// Transposed bf16 gemv: hands every argument, unchanged, to
/// `bf16_gemv_trans_stub` on the CPU device.
static void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy) {
  bf16_gemv_trans_stub(
      kCPU,
      m,
      n,
      alpha,
      a,
      lda,
      x,
      incx,
      beta,
      y,
      incy);
}
/// BFloat16 fast path: unpacks the pointer-style arguments and calls
/// `bf16_gemv_trans`. Only the transposed case is supported, which is
/// asserted in debug builds.
template <>
void gemv_fast_path<at::BFloat16>(
    const char* trans,
    const int* m,
    const int* n,
    const at::BFloat16* alpha,
    const at::BFloat16* a,
    const int* lda,
    const at::BFloat16* x,
    const int* incx,
    const at::BFloat16* beta,
    at::BFloat16* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  bf16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}
#if !defined(__aarch64__)
// Currently, only fp16_gemv_trans is built for non-aarch64.
/// Half (non-aarch64): the fast path covers only the transposed case with
/// unit `incx` and alpha == 1. `beta` is not inspected.
///
/// `gemv` narrows m, n, lda and incy to `int` before calling
/// `gemv_fast_path`, so values above INT_MAX must be rejected here.
/// Otherwise they would be truncated silently; rejecting them sends such
/// calls to the generic 64-bit loop instead.
template <>
bool gemv_use_fast_path<at::Half>(
    char trans,
    int64_t m,
    int64_t n,
    at::Half alpha,
    int64_t lda,
    int64_t incx,
    [[maybe_unused]] at::Half beta,
    int64_t incy) {
  // clang is capable of constant-folding fp16_ieee_from_fp32_value,
  // so use it to get simple integer comparisons.
  // https://godbolt.org/z/v936hroYb
  using c10::detail::fp16_ieee_from_fp32_value;
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const bool fits_in_int =
      m <= kIntMax && n <= kIntMax && lda <= kIntMax && incy <= kIntMax;
  return fits_in_int && (trans == 'T' || trans == 't') && incx == 1 &&
      alpha.x == fp16_ieee_from_fp32_value(1.0f);
}
/// Half fast path (non-aarch64): unpacks the pointer-style arguments and
/// calls `fp16_gemv_trans`. Only the transposed case is supported, which is
/// asserted in debug builds.
template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  fp16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}
#else // defined(__aarch64__)
/// Half (aarch64): the fast path requires unit `incx` and alpha == 1. The
/// non-transposed kernel additionally requires beta == 0; the transposed
/// kernel accepts any beta.
///
/// `gemv` narrows m, n, lda and incy to `int` before calling
/// `gemv_fast_path`, so values above INT_MAX must be rejected here.
/// Otherwise they would be truncated silently; rejecting them sends such
/// calls to the generic 64-bit loop instead.
template <>
bool gemv_use_fast_path<at::Half>(
    char trans,
    int64_t m,
    int64_t n,
    at::Half alpha,
    int64_t lda,
    int64_t incx,
    at::Half beta,
    int64_t incy) {
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const bool fits_in_int =
      m <= kIntMax && n <= kIntMax && lda <= kIntMax && incy <= kIntMax;
  return fits_in_int && incx == 1 &&
      c10::detail::fp16_from_bits(alpha.x) == 1.0f &&
      // TODO: enable nonzero beta for fp16_gemv_notrans
      (c10::detail::fp16_from_bits(beta.x) == 0.0f || trans == 't' || trans == 'T');
}
/// Half fast path (aarch64): unpacks the pointer-style arguments and routes
/// to `fp16_gemv_trans` when `trans` is 'T'/'t', otherwise to
/// `fp16_gemv_notrans`.
template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  using namespace c10::detail;
  const bool transposed = (trans[0] == 'T') || (trans[0] == 't');
  if (transposed) {
    fp16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
    return;
  }
  fp16_gemv_notrans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}
// The #else branch that ends here is compiled only when __aarch64__ is defined.
#endif // !defined(__aarch64__)
#else // defined(C10_MOBILE)
// Mobile builds have no Half/BFloat16 fast-path specializations, so
// explicitly instantiate the generic templates for them.
INSTANTIATE(c10::Half)
INSTANTIATE(c10::BFloat16)
#endif // !defined(C10_MOBILE)
#endif // AT_BUILD_WITH_BLAS
#undef INSTANTIATE
} // namespace blas_impl
/// Scales a strided vector in place: x[i * incx] *= a for i in [0, n).
///
/// When BLAS is available and `scal_use_fast_path` accepts the arguments, the
/// work is delegated to `scal_fast_path`. Otherwise a plain loop is used. In
/// that loop a scale of exactly zero stores 0 instead of multiplying.
template <typename scalar_t>
static inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
{
  // The stride is irrelevant for a single element, so normalise it.
  if (n == 1) {
    incx = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
    int blas_n = static_cast<int>(n);
    int blas_incx = static_cast<int>(incx);
    blas_impl::scal_fast_path<scalar_t>(&blas_n, &a, x, &blas_incx);
    return;
  }
#endif
  const bool scale_is_zero = (a == scalar_t(0));
  for (const auto idx : c10::irange(n)) {
    scalar_t& elem = x[idx * incx];
    if (scale_is_zero) {
      elem = 0;
    } else {
      elem *= a;
    }
  }
}
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy) {
if(n == 1) lda = m;
#if AT_BUILD_WITH_BLAS()
if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
int i_m = (int)m;
int i_n = (int)n;
int i_lda = (int)lda;
int i_incx = (int)incx;
int i_incy = (int)incy;
blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
return;
}
#endif
using opmath_t = at::opmath_type<scalar_t>;
if ((trans == 'T') || (trans == 't')) {
for (const auto i : c10::irange(n)) {
opmath_t sum = 0;
const scalar_t *row_ = a + lda * i;
for (const auto j : c10::irange(m)) {
sum += static_cast<opmath_t>(x[j * incx]) * static_cast<opmath_t>(row_[j]);
}
if (beta == scalar_t(0)) {
y[i * incy] = alpha * sum;
} else {
y[i * incy] = beta * y[i * incy] + alpha * sum;
}
}
} else {
if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);
constexpr bool is_low_precision = !std::is_same_v<opmath_t, scalar_t>;
std::vector<opmath_t> sum;
if constexpr (is_low_precision) {
sum.resize(m);
}
for (const auto j : c10::irange(n)) {
const scalar_t *column_ = a + lda * j;
opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
for (const auto i : c10::irange(m)) {
//output values are ignored if beta is 0, and set to 0, nans and infs are not propagated
if (j==0 && beta==scalar_t(0)) {
if constexpr (!is_low_precision) {
y[i * incy] = 0;
}
}
if constexpr (is_low_precision) {
sum[i] += z * column_[i];
} else {
y[i * incy] += z * column_[i];
}
}
}
if constexpr (is_low_precision) {
if (beta == scalar_t(0)) {
for (const auto i : c10::irange(m)) {
y[i * incy] = sum[i];
}
} else {
for (const auto i : c10::irange(m)) {
y[i * incy] += sum[i];
}
}
}
}
return;
}
// Explicitly instantiate gemv<scalar_t> for each type listed by the
// AT_FORALL_* macros (the second macro argument is unused). Judging by the
// macro names this covers the standard scalar types plus BFloat16, Half and
// the complex types - confirm against the macro definitions.
#define INSTANTIATE(scalar_t, _) \
template void gemv<scalar_t>(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
AT_FORALL_SCALAR_TYPES_AND2(BFloat16, Half, INSTANTIATE)
AT_FORALL_COMPLEX_TYPES(INSTANTIATE)
#undef INSTANTIATE
namespace blas_impl {
#if AT_BUILD_WITH_BLAS()
// BLAS-backed dot products. Sizes and strides are taken by value as `int` so
// that their addresses can be handed to the pointer-argument BLAS routines.

/// float dot product via `sdot_`.
static float dot_fast_path(int n, const float* x, int incx, const float* y, int incy) {
  return sdot_(&n, x, &incx, y, &incy);
}

/// double dot product via `ddot_` (whose declaration is not const-correct).
static double dot_fast_path(int n, const double* x, int incx, const double* y, int incy) {
  double* const x_mut = const_cast<double*>(x);
  double* const y_mut = const_cast<double*>(y);
  return ddot_(&n, x_mut, &incx, y_mut, &incy);
}

/// complex<float> dot product via `cdotc_`.
static c10::complex<float> vdot_fast_path(int n, const c10::complex<float>* x, int incx, const c10::complex<float>* y, int incy) {
  c10::complex<float> out;
  cdotc_(
      reinterpret_cast<std::complex<float>* >(&out),
      &n,
      reinterpret_cast<const std::complex<float>*>(x),
      &incx,
      reinterpret_cast<const std::complex<float>*>(y),
      &incy);
  return out;
}

/// complex<double> dot product via `zdotc_`.
static c10::complex<double> vdot_fast_path(int n, const c10::complex<double>* x, int incx, const c10::complex<double>* y, int incy) {
  c10::complex<double> out;
  zdotc_(
      reinterpret_cast<std::complex<double>* >(&out),
      &n,
      reinterpret_cast<const std::complex<double>*>(x),
      &incx,
      reinterpret_cast<const std::complex<double>*>(y),
      &incy);
  return out;
}

/// complex<double> dot product via `zdotu_`.
static c10::complex<double> dot_fast_path(int n, const c10::complex<double>* x, int incx, const c10::complex<double>* y, int incy) {
  c10::complex<double> out;
  zdotu_(
      reinterpret_cast<std::complex<double>* >(&out),
      &n,
      reinterpret_cast<const std::complex<double>*>(x),
      &incx,
      reinterpret_cast<const std::complex<double>*>(y),
      &incy);
  return out;
}

/// complex<float> dot product via `cdotu_`.
static c10::complex<float> dot_fast_path(int n, const c10::complex<float>* x, int incx, const c10::complex<float>* y, int incy) {
  c10::complex<float> out;
  cdotu_(
      reinterpret_cast<std::complex<float>* >(&out),
      &n,
      reinterpret_cast<const std::complex<float>*>(x),
      &incx,
      reinterpret_cast<const std::complex<float>*>(y),
      &incy);
  return out;
}
#endif
template <typename scalar_t, typename Functor>
static scalar_t dot_naive(
int64_t n,
const scalar_t* x,
int64_t incx,
const scalar_t* y,
int64_t incy,
Functor op) {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t sum = 0;
for (int64_t i = 0; i < n; i++) {
sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
}
return static_cast<scalar_t>(sum);
}
} // namespace blas_impl
/// Dot product for the types that have a BLAS `dot_fast_path` overload.
///
/// Uses BLAS when it is built in and n, incx and incy are all no larger than
/// INT_MAX. Otherwise falls back to the naive strided loop.
template <typename scalar_t>
static scalar_t dot_impl_floating(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy)
{
  // Strides are irrelevant for a single element, so normalise them.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  const bool fits_in_int =
      (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
  if (fits_in_int) {
    return blas_impl::dot_fast_path(n, x, incx, y, incy);
  }
#endif
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
}
template <typename scalar_t>
scalar_t dot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy) {
if (n == 1) {
incx = 1;
incy = 1;
}
return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<at::opmath_type<scalar_t>>{});
}
// float, double and the two complex types all route through
// dot_impl_floating.

/// float specialization.
template <>
float dot_impl(
    int64_t n,
    const float* x,
    int64_t incx,
    const float* y,
    int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

/// double specialization.
template <>
double dot_impl(
    int64_t n,
    const double* x,
    int64_t incx,
    const double* y,
    int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

/// complex<double> specialization.
template <>
c10::complex<double> dot_impl(
    int64_t n,
    const c10::complex<double>* x,
    int64_t incx,
    const c10::complex<double>* y,
    int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

/// complex<float> specialization.
template <>
c10::complex<float> dot_impl(
    int64_t n,
    const c10::complex<float>* x,
    int64_t incx,
    const c10::complex<float>* y,
    int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}
/// Half specialization: on non-mobile builds, inputs with unit stride on both
/// sides go to the `fp16_dot` kernel. Everything else uses the naive loop
/// with float multiplication.
template <>
Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
  // Strides are irrelevant for a single element, so normalise them.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if !defined(C10_MOBILE)
  const bool unit_strides = (incx == 1) && (incy == 1);
  if (unit_strides) {
    return blas_impl::fp16_dot(n, x, incx, y, incy);
  }
#endif // !defined(C10_MOBILE)
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
}
/// BFloat16 specialization: on non-mobile builds, inputs with unit stride on
/// both sides go to the `bf16_dot` kernel. Everything else uses the naive
/// loop with float multiplication.
template <>
BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
  // Strides are irrelevant for a single element, so normalise them.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if !defined(C10_MOBILE)
  const bool unit_strides = (incx == 1) && (incy == 1);
  if (unit_strides) {
    return blas_impl::bf16_dot(n, x, incx, y, incy);
  }
#endif // !defined(C10_MOBILE)
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
}
namespace {
/// Binary functor used for the conjugated dot product: conj(x) * y.
template <typename scalar_t>
struct vdot_op {
  scalar_t operator()(scalar_t x, scalar_t y) {
    const scalar_t x_conj = std::conj(x);
    return x_conj * y;
  }
};
} // anonymous namespace
/// Conjugated dot product: sum over i of conj(x[i]) * y[i].
///
/// Uses BLAS `vdot_fast_path` when BLAS is built in and n, incx and incy are
/// all no larger than INT_MAX. Otherwise falls back to the naive strided loop
/// with `vdot_op`.
template <typename scalar_t>
scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy) {
  // Strides are irrelevant for a single element, so normalise them.
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  const bool fits_in_int =
      (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX);
  if (fits_in_int) {
    return blas_impl::vdot_fast_path(n, x, incx, y, incy);
  }
#endif
  return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
}
// Explicit instantiations of dot_impl for the types served by the generic
// template. Types with an explicit specialization (`float`, `double`,
// `c10::complex<float>`, `c10::complex<double>`, `Half` and `BFloat16`) are
// skipped.
#define INSTANTIATE_DOT_IMPL(scalar_t) \
template scalar_t dot_impl<scalar_t>( \
int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
INSTANTIATE_DOT_IMPL(uint8_t)
INSTANTIATE_DOT_IMPL(int8_t)
INSTANTIATE_DOT_IMPL(int16_t)
INSTANTIATE_DOT_IMPL(int)
INSTANTIATE_DOT_IMPL(int64_t)
// Explicit instantiations of vdot_impl for the complex types.
#define INSTANTIATE_VDOT_IMPL(scalar_t) \
template scalar_t vdot_impl<scalar_t>( \
int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
INSTANTIATE_VDOT_IMPL(c10::complex<float>)
INSTANTIATE_VDOT_IMPL(c10::complex<double>)
// Both helper macros are local to this file; undefine both of them
// (INSTANTIATE_VDOT_IMPL was previously left defined).
#undef INSTANTIATE_DOT_IMPL
#undef INSTANTIATE_VDOT_IMPL
} // namespace at::native