#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <climits>
#include <complex>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>

#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

namespace {

/// Wrapper for const_cast with type-inference.
///
/// Use this to call into APIs that are not const-correct.
template <typename T>
T* remove_const(const T* x) {
  return const_cast<T*>(x);
}

} // namespace

#if AT_BUILD_WITH_BLAS()
#ifndef _ARMPL_H
extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy);
extern "C" void dscal_(int *n, double *a, double *x, int *incx);
extern "C" void sscal_(int *n, float *a, float *x, int *incx);
extern "C" void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
#endif

#if AT_BLAS_F2C()
# define ffloat double
#else
# define ffloat float
#endif

#if AT_BLAS_USE_CBLAS_DOT()
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
extern "C" void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu);
extern "C" void cblas_cdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);
extern "C" void cblas_zdotc_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotc);

#ifndef _ARMPL_H
static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
{
  return cblas_sdot(*n, x, *incx, y, *incy);
}
#endif
static inline void cdotu_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
                          const std::complex<float> *y, const int *incy) {
  cblas_cdotu_sub(*n, x, *incx, y, *incy, res);
}
static inline void zdotu_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
                          const std::complex<double> *y, const int *incy) {
  cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
}
static inline void cdotc_(std::complex<float> *res, const int *n, const std::complex<float> *x, const int *incx,
                          const std::complex<float> *y, const int *incy) {
  cblas_cdotc_sub(*n, x, *incx, y, *incy, res);
}
static inline void zdotc_(std::complex<double> *res, const int *n, const std::complex<double> *x, const int *incx,
                          const std::complex<double> *y, const int *incy) {
  cblas_zdotc_sub(*n, x, *incx, y, *incy, res);
}
#else
extern "C" ffloat sdot_(int *n, const float *x, int *incx, const float *y, int *incy);
extern "C" void cdotu_(std::complex<float> *res, int *n, const std::complex<float> *x, int *incx, const std::complex<float> *y, int *incy);
extern "C" void zdotu_(std::complex<double> *res, int *n, const std::complex<double> *x, int *incx, const std::complex<double> *y, int *incy);
extern "C" void cdotc_(std::complex<float> *res, int *n, const std::complex<float> *x, int *incx, const std::complex<float> *y, int *incy);
extern "C" void zdotc_(std::complex<double> *res, int *n, const std::complex<double> *x, int *incx, const std::complex<double> *y, int *incy);
#endif // AT_BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS

namespace at::native {

#if !defined(C10_MOBILE)
DEFINE_DISPATCH(fp16_gemv_trans_stub);
DEFINE_DISPATCH(bf16_gemv_trans_stub);
DEFINE_DISPATCH(fp16_dot_stub);
DEFINE_DISPATCH(bf16_dot_stub);
#endif // !defined(C10_MOBILE)

namespace blas_impl {
#if !defined(C10_MOBILE)
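// The Half/BFloat16 helpers below only forward to the DispatchStub entry
// points defined above (fp16_gemv_trans_stub & friends), so the vectorized
// kernels can be compiled per CPU capability and selected at runtime.
// Illustrative call, matching the declaration that follows:
//   blas_impl::fp16_gemv_trans(m, n, /*alpha=*/1.0f, a, lda, x, 1, /*beta=*/0.0f, y, 1);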
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy);

void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy) {
  fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
}

static float fp16_dot(
    const int64_t n,
    const Half* x,
    const int64_t incx,
    const Half* y,
    const int64_t incy) {
  return fp16_dot_stub(kCPU, n, x, incx, y, incy);
}

static float bf16_dot(
    const int64_t n,
    const BFloat16* x,
    const int64_t incx,
    const BFloat16* y,
    const int64_t incy) {
  return bf16_dot_stub(kCPU, n, x, incx, y, incy);
}
#endif // !defined(C10_MOBILE)

#if defined(__aarch64__) && !defined(C10_MOBILE)
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
static void fp16_gemv_notrans_fp16_arith(int m, int n, const float16_t* a, const int lda, const float16_t* x, float16_t* y) {
  for (auto j = 0; j < n; j++) {
    // Broadcast x[j] and accumulate column j of A into y, four rows at a time.
    auto vecCol = vdup_n_f16(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto yf16 = y + i;
      auto matRow = vld1_f16(column + i);
      auto resVec = j != 0 ? vld1_f16(yf16) : vdup_n_f16(0);
      resVec = vfma_lane_f16(resVec, matRow, vecCol, 0);
      vst1_f16(yf16, resVec);
    }
  }
}
#endif

static void fp16_gemv_notrans_fp32_arith(int m, int n, const float16_t* a, const int lda, const float16_t* x, float16_t* y) {
  // Accumulate in fp32 and convert back to fp16 only at the end.
  std::vector<float> sum(m);
  for (auto j = 0; j < n; j++) {
    auto vecCol = vdup_n_f32(x[j]);
    const auto* column = a + lda * j;
    for (auto i = 0; i < m; i += 4) {
      auto sf32 = sum.data() + i;
      auto matRow = vcvt_f32_f16(vld1_f16(column + i));
      auto resVec = j != 0 ? vld1q_f32(sf32) : vdupq_n_f32(0);
      resVec = vfmaq_lane_f32(resVec, matRow, vecCol, 0);
      vst1q_f32(sf32, resVec);
    }
  }

  for (auto i = 0; i < m; i += 4) {
    vst1_f16(y + i, vcvt_f16_f32(vld1q_f32(sum.data() + i)));
  }
}

void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy);

void fp16_gemv_notrans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy) {
  if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && incy == 1) {
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
    if (at::globalContext().allowFP16ReductionCPU()) {
      return fp16_gemv_notrans_fp16_arith(m, n, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(x), reinterpret_cast<float16_t*>(y));
    }
#endif
    return fp16_gemv_notrans_fp32_arith(m, n, reinterpret_cast<const float16_t*>(a), lda, reinterpret_cast<const float16_t*>(x), reinterpret_cast<float16_t*>(y));
  }
  std::vector<float> sum(m);
  for (const auto j : c10::irange(n)) {
    const auto* column_ = a + lda * j;
    auto z = alpha * x[j * incx];
    for (const auto i : c10::irange(m)) {
      sum[i] += z * column_[i];
    }
  }
  if (beta == 0.0) {
    for (const auto i : c10::irange(m)) {
      y[i * incy] = sum[i];
    }
  } else {
    for (const auto i : c10::irange(m)) {
      y[i * incy] += sum[i];
    }
  }
}
#endif // defined(__aarch64__) && !defined(C10_MOBILE)

template <typename scalar_t>
static bool scal_use_fast_path(
    [[maybe_unused]] int64_t n,
    [[maybe_unused]] int64_t incx) {
  return false;
}

template <typename scalar_t>
static bool gemv_use_fast_path(
    [[maybe_unused]] char trans,
    [[maybe_unused]] int64_t m,
    [[maybe_unused]] int64_t n,
    [[maybe_unused]] scalar_t alpha,
    [[maybe_unused]] int64_t lda,
    [[maybe_unused]] int64_t incx,
    [[maybe_unused]] scalar_t beta,
    [[maybe_unused]] int64_t incy) {
  return false;
}
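// Fast-path protocol: these generic *_use_fast_path predicates return false,
// and the generic *_fast_path fallbacks below assert if they are ever
// reached. Per-type specializations (under AT_BUILD_WITH_BLAS() further down)
// opt in only when every size and stride fits BLAS's int-based interface.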
template <typename scalar_t>
static void scal_fast_path(
    [[maybe_unused]] int* n,
    [[maybe_unused]] scalar_t* a,
    [[maybe_unused]] scalar_t* x,
    [[maybe_unused]] int* incx) {
  TORCH_INTERNAL_ASSERT(
      false, "scal_fast_path shouldn't be called for this configuration");
}

template <typename scalar_t>
static void gemv_fast_path(
    [[maybe_unused]] const char* trans,
    [[maybe_unused]] const int* m,
    [[maybe_unused]] const int* n,
    [[maybe_unused]] const scalar_t* alpha,
    [[maybe_unused]] const scalar_t* a,
    [[maybe_unused]] const int* lda,
    [[maybe_unused]] const scalar_t* x,
    [[maybe_unused]] const int* incx,
    [[maybe_unused]] const scalar_t* beta,
    [[maybe_unused]] scalar_t* y,
    [[maybe_unused]] const int* incy) {
  TORCH_INTERNAL_ASSERT(
      false, "gemv_fast_path shouldn't be called for this configuration");
}

#define INSTANTIATE(scalar_t) \
  template bool scal_use_fast_path<scalar_t>(int64_t n, int64_t incx); \
  template bool gemv_use_fast_path<scalar_t>( \
      char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \
  template void gemv_fast_path<scalar_t>( \
      const char* trans, const int* m, const int* n, const scalar_t* alpha, const scalar_t* a, const int* lda, \
      const scalar_t* x, const int* incx, const scalar_t* beta, scalar_t* y, const int* incy); \
  template void scal_fast_path<scalar_t>(int* n, scalar_t* a, scalar_t* x, int* incx);

#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
  auto intmax = std::numeric_limits<int>::max();
  return n <= intmax && incx <= intmax;
}

template <>
bool scal_use_fast_path<float>(int64_t n, int64_t incx) {
  return scal_use_fast_path<double>(n, incx);
}

template <>
void scal_fast_path<double>(int* n, double* a, double* x, int* incx) {
  dscal_(n, a, x, incx);
}

template <>
void scal_fast_path<float>(int* n, float* a, float* x, int* incx) {
  sscal_(n, a, x, incx);
}

template <>
bool gemv_use_fast_path<float>(
    [[maybe_unused]] char trans,
    int64_t m,
    int64_t n,
    [[maybe_unused]] float alpha,
    int64_t lda,
    int64_t incx,
    [[maybe_unused]] float beta,
    int64_t incy) {
  auto intmax = std::numeric_limits<int>::max();
  return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
      (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

template <>
bool gemv_use_fast_path<double>(
    [[maybe_unused]] char trans,
    int64_t m,
    int64_t n,
    [[maybe_unused]] double alpha,
    int64_t lda,
    int64_t incx,
    [[maybe_unused]] double beta,
    int64_t incy) {
  return gemv_use_fast_path<float>(
      trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
}

template <>
void gemv_fast_path<double>(const char* trans, const int* m, const int* n, const double* alpha, const double* a, const int* lda, const double* x, const int* incx, const double* beta, double* y, const int* incy) {
  dgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
}

template <>
void gemv_fast_path<float>(const char* trans, const int* m, const int* n, const float* alpha, const float* a, const int* lda, const float* x, const int* incx, const float* beta, float* y, const int* incy) {
  sgemv_(remove_const(trans), remove_const(m), remove_const(n), remove_const(alpha), remove_const(a), remove_const(lda), remove_const(x), remove_const(incx), remove_const(beta), y, remove_const(incy));
}

INSTANTIATE(uint8_t)
INSTANTIATE(int8_t)
INSTANTIATE(int16_t)
INSTANTIATE(int)
INSTANTIATE(int64_t)
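// Reference BLAS has no fp16/bf16 GEMV, so on non-mobile builds the
// specializations below route at::Half and at::BFloat16 through the custom
// reduced-precision kernels instead. Note the stricter predicates: they
// accept only the shapes those kernels support (transposed, contiguous x,
// alpha == 1, and so on).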
#if !defined(C10_MOBILE)
template <>
bool gemv_use_fast_path<at::BFloat16>(
    [[maybe_unused]] char trans,
    [[maybe_unused]] int64_t m,
    [[maybe_unused]] int64_t n,
    at::BFloat16 alpha,
    [[maybe_unused]] int64_t lda,
    [[maybe_unused]] int64_t incx,
    at::BFloat16 beta,
    [[maybe_unused]] int64_t incy) {
  return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0;
}

static void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy) {
  return bf16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
}

template <>
void gemv_fast_path<at::BFloat16>(
    const char* trans,
    const int* m,
    const int* n,
    const at::BFloat16* alpha,
    const at::BFloat16* a,
    const int* lda,
    const at::BFloat16* x,
    const int* incx,
    const at::BFloat16* beta,
    at::BFloat16* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  bf16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}

#if !defined(__aarch64__)
// Currently, only fp16_gemv_trans is built for non-aarch64.
template <>
bool gemv_use_fast_path<at::Half>(
    char trans,
    [[maybe_unused]] int64_t m,
    [[maybe_unused]] int64_t n,
    at::Half alpha,
    [[maybe_unused]] int64_t lda,
    [[maybe_unused]] int64_t incx,
    [[maybe_unused]] at::Half beta,
    [[maybe_unused]] int64_t incy) {
  // clang is capable of constant-folding fp16_ieee_from_fp32_value,
  // so use it to get simple integer comparisons.
  // https://godbolt.org/z/v936hroYb
  using c10::detail::fp16_ieee_from_fp32_value;
  return (trans == 'T' || trans == 't') && incx == 1 &&
      alpha.x == fp16_ieee_from_fp32_value(1.0f);
}

template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't');
  fp16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
}
#else // i.e. defined(__aarch64__): this branch handles the aarch64 build
template <>
bool gemv_use_fast_path<at::Half>(
    char trans,
    [[maybe_unused]] int64_t m,
    [[maybe_unused]] int64_t n,
    at::Half alpha,
    [[maybe_unused]] int64_t lda,
    [[maybe_unused]] int64_t incx,
    at::Half beta,
    [[maybe_unused]] int64_t incy) {
  return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f &&
      // TODO: enable nonzero beta for fp16_gemv_notrans
      (c10::detail::fp16_from_bits(beta.x) == 0.0f || trans == 't' || trans == 'T');
}
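// On aarch64 both orientations have kernels: 'T'/'t' goes to fp16_gemv_trans,
// anything else to fp16_gemv_notrans. The predicate above guarantees
// beta == 0 in the notrans case (see the TODO), so that kernel may overwrite
// y rather than accumulate into it.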
template <>
void gemv_fast_path<at::Half>(
    const char* trans,
    const int* m,
    const int* n,
    const at::Half* alpha,
    const at::Half* a,
    const int* lda,
    const at::Half* x,
    const int* incx,
    const at::Half* beta,
    at::Half* y,
    const int* incy) {
  using namespace c10::detail;
  if ((trans[0] == 'T') || (trans[0] == 't')) {
    fp16_gemv_trans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
  } else {
    fp16_gemv_notrans(*m, *n, *alpha, a, *lda, x, *incx, *beta, y, *incy);
  }
}
#endif // !defined(__aarch64__)
#else // !defined(C10_MOBILE)
INSTANTIATE(c10::Half)
INSTANTIATE(c10::BFloat16)
#endif // !defined(C10_MOBILE)
#endif // AT_BUILD_WITH_BLAS
#undef INSTANTIATE

} // namespace blas_impl

template <typename scalar_t>
static inline void scal(int64_t n, scalar_t a, scalar_t* x, int64_t incx) {
  if (n == 1) incx = 1;
#if AT_BUILD_WITH_BLAS()
  if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
    return;
  }
#endif
  for (const auto i : c10::irange(n)) {
    if (a == scalar_t(0)) {
      x[i * incx] = 0;
    } else {
      x[i * incx] *= a;
    }
  }
}

template <typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t* a, int64_t lda, const scalar_t* x, int64_t incx, scalar_t beta, scalar_t* y, int64_t incy) {
  if (n == 1) lda = m;

#if AT_BUILD_WITH_BLAS()
  if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
    TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
    int i_m = (int)m;
    int i_n = (int)n;
    int i_lda = (int)lda;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
    blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
    return;
  }
#endif

  using opmath_t = at::opmath_type<scalar_t>;
  if ((trans == 'T') || (trans == 't')) {
    for (const auto i : c10::irange(n)) {
      opmath_t sum = 0;
      const scalar_t* row_ = a + lda * i;
      for (const auto j : c10::irange(m)) {
        sum += static_cast<opmath_t>(x[j * incx]) * static_cast<opmath_t>(row_[j]);
      }
      if (beta == scalar_t(0)) {
        y[i * incy] = alpha * sum;
      } else {
        y[i * incy] = beta * y[i * incy] + alpha * sum;
      }
    }
  } else {
    if (beta != scalar_t(1) && beta != scalar_t(0)) scal<scalar_t>(m, beta, y, incy);

    constexpr bool is_low_precision = !std::is_same_v<opmath_t, scalar_t>;
    std::vector<opmath_t> sum;
    if constexpr (is_low_precision) {
      sum.resize(m);
    }
    for (const auto j : c10::irange(n)) {
      const scalar_t* column_ = a + lda * j;
      opmath_t z = alpha * static_cast<opmath_t>(x[j * incx]);
      for (const auto i : c10::irange(m)) {
        // If beta is 0, the existing values of y are ignored and overwritten
        // with 0, so NaNs and Infs in y are not propagated.
        if (j == 0 && beta == scalar_t(0)) {
          if constexpr (!is_low_precision) {
            y[i * incy] = 0;
          }
        }
        if constexpr (is_low_precision) {
          sum[i] += z * column_[i];
        } else {
          y[i * incy] += z * column_[i];
        }
      }
    }
    if constexpr (is_low_precision) {
      if (beta == scalar_t(0)) {
        for (const auto i : c10::irange(m)) {
          y[i * incy] = sum[i];
        }
      } else {
        for (const auto i : c10::irange(m)) {
          y[i * incy] += sum[i];
        }
      }
    }
  }
}
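// Illustrative use (column-major A of shape m x n, lda >= m):
//   y := alpha * A * x + beta * y    ->  gemv('n', m, n, alpha, a, lda, x, incx, beta, y, incy)
//   y := alpha * A^T * x + beta * y  ->  gemv('t', m, n, alpha, a, lda, x, incx, beta, y, incy)
// With trans == 'n' the output y has length m; with 't'/'T' it has length n.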
#define INSTANTIATE(scalar_t, _) \
  template void gemv<scalar_t>( \
      char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t* a, int64_t lda, \
      const scalar_t* x, int64_t incx, scalar_t beta, scalar_t* y, int64_t incy);
AT_FORALL_SCALAR_TYPES_AND2(BFloat16, Half, INSTANTIATE)
AT_FORALL_COMPLEX_TYPES(INSTANTIATE)
#undef INSTANTIATE

namespace blas_impl {
#if AT_BUILD_WITH_BLAS()
static float dot_fast_path(int n, const float* x, int incx, const float* y, int incy) {
  return sdot_(&n, x, &incx, y, &incy);
}

static double dot_fast_path(int n, const double* x, int incx, const double* y, int incy) {
  return ddot_(&n, const_cast<double*>(x), &incx, const_cast<double*>(y), &incy);
}

static c10::complex<float> vdot_fast_path(int n, const c10::complex<float>* x, int incx, const c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotc_(reinterpret_cast<std::complex<float>*>(&result), &n, reinterpret_cast<const std::complex<float>*>(x), &incx, reinterpret_cast<const std::complex<float>*>(y), &incy);
  return result;
}

static c10::complex<double> vdot_fast_path(int n, const c10::complex<double>* x, int incx, const c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotc_(reinterpret_cast<std::complex<double>*>(&result), &n, reinterpret_cast<const std::complex<double>*>(x), &incx, reinterpret_cast<const std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<double> dot_fast_path(int n, const c10::complex<double>* x, int incx, const c10::complex<double>* y, int incy) {
  c10::complex<double> result;
  zdotu_(reinterpret_cast<std::complex<double>*>(&result), &n, reinterpret_cast<const std::complex<double>*>(x), &incx, reinterpret_cast<const std::complex<double>*>(y), &incy);
  return result;
}

static c10::complex<float> dot_fast_path(int n, const c10::complex<float>* x, int incx, const c10::complex<float>* y, int incy) {
  c10::complex<float> result;
  cdotu_(reinterpret_cast<std::complex<float>*>(&result), &n, reinterpret_cast<const std::complex<float>*>(x), &incx, reinterpret_cast<const std::complex<float>*>(y), &incy);
  return result;
}
#endif

template <typename scalar_t, typename Functor>
static scalar_t dot_naive(
    int64_t n,
    const scalar_t* x,
    int64_t incx,
    const scalar_t* y,
    int64_t incy,
    Functor op) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t sum = 0;
  for (int64_t i = 0; i < n; i++) {
    sum += op(static_cast<opmath_t>(x[i * incx]), static_cast<opmath_t>(y[i * incy]));
  }
  return static_cast<scalar_t>(sum);
}

} // namespace blas_impl

template <typename scalar_t>
static scalar_t dot_impl_floating(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::dot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
  }
#else
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
#endif
}

template <typename scalar_t>
scalar_t dot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<at::opmath_type<scalar_t>>{});
}

template <>
float dot_impl(int64_t n, const float* x, int64_t incx, const float* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
double dot_impl(int64_t n, const double* x, int64_t incx, const double* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<double> dot_impl(int64_t n, const c10::complex<double>* x, int64_t incx, const c10::complex<double>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
c10::complex<float> dot_impl(int64_t n, const c10::complex<float>* x, int64_t incx, const c10::complex<float>* y, int64_t incy) {
  return dot_impl_floating(n, x, incx, y, incy);
}

template <>
Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if !defined(C10_MOBILE)
  if (incx == 1 && incy == 1) {
    return blas_impl::fp16_dot(n, x, incx, y, incy);
  }
#endif // !defined(C10_MOBILE)
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
}

template <>
BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if !defined(C10_MOBILE)
  if (incx == 1 && incy == 1) {
    return blas_impl::bf16_dot(n, x, incx, y, incy);
  }
#endif // !defined(C10_MOBILE)
  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
}

namespace {
template <typename scalar_t>
struct vdot_op {
  scalar_t operator()(scalar_t x, scalar_t y) {
    return std::conj(x) * y;
  }
};
} // anonymous namespace

template <typename scalar_t>
scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    return blas_impl::vdot_fast_path(n, x, incx, y, incy);
  } else {
    return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
  }
#else
  return blas_impl::dot_naive(n, x, incx, y, incy, vdot_op<scalar_t>{});
#endif
}
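// vdot_impl computes the conjugated dot product sum_i conj(x[i]) * y[i]
// (cdotc_/zdotc_ on the BLAS path, vdot_op on the naive path), whereas
// dot_impl for complex types computes the unconjugated sum via cdotu_/zdotu_.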
// Skip reinstantiating the explicitly specialized types
// `float`, `double`, `half` & `bfloat16`.
#define INSTANTIATE_DOT_IMPL(scalar_t)  \
  template scalar_t dot_impl<scalar_t>( \
      int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy);
INSTANTIATE_DOT_IMPL(uint8_t)
INSTANTIATE_DOT_IMPL(int8_t)
INSTANTIATE_DOT_IMPL(int16_t)
INSTANTIATE_DOT_IMPL(int)
INSTANTIATE_DOT_IMPL(int64_t)
#undef INSTANTIATE_DOT_IMPL

#define INSTANTIATE_VDOT_IMPL(scalar_t)  \
  template scalar_t vdot_impl<scalar_t>( \
      int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y, int64_t incy);
INSTANTIATE_VDOT_IMPL(c10::complex<float>)
INSTANTIATE_VDOT_IMPL(c10::complex<double>)
#undef INSTANTIATE_VDOT_IMPL

} // namespace at::native