sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/src/UtilsSve.cc

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "./TransposeUtils.h"
#include "./TransposeUtilsSve.h"
namespace fbgemm {
namespace internal {
#if HAVE_SVE
template <>
void transpose_sve(
int64_t M,
int64_t N,
const float* src,
int64_t ld_src,
float* dst,
int64_t ld_dst) {
int64_t ib = 0, jb = 0;
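// Added note: x encodes the column remainder r = N % 8. XOR-ing with 4 maps
// r = 1..3 to x > 4, r = 4 to x == 0, and r = 0 or 5..7 to x in 1..4, so the
// three branches below dispatch on the width of the column remainder.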
int64_t x = (N & 7) ^ 4;
if (x > 4) {
// If the remainder has n < 4 columns, handle it with two 4-row calls to the
// small masked kernel (transpose_kernel_mxn_small_sve<4>): roughly
// 2 * (2 * 4 + 2 * n) = 16 + 4n instructions versus 3 * 8 + 2 * n = 24 + 2n
// for the masked 8-row kernel (transpose_kernel_mxn_large_sve<8>), e.g.
// 28 vs. 30 instructions for n = 3. (The instruction counts are carried over
// from the x86 version of this heuristic.)
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_sve(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
for (int64_t i = ib; i < ib + 8; i += 4) {
transpose_kernel_mxn_small_sve<4>(
N - jb,
&src[i * ld_src + jb],
ld_src,
&dst[i + jb * ld_dst],
ld_dst);
}
}
} else if (x == 0) {
// If the remainder has exactly 4 columns, use the unmasked 4x4 kernel: about
// 2 * 16 = 32 instructions, versus 3 * 8 + 2 * 4 = 32 instructions plus
// looping overhead for the masked 8-row kernel. (Counts carried over from
// the x86 version of this heuristic.)
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_sve(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
for (int64_t i = ib; i < ib + 8; i += 4) {
transpose_kernel_4x4_sve(
&src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
}
}
} else {
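// Added note: remainder of 0 or 5-7 columns. Any leftover columns of each
// 8-row block are handled by a single masked 8-row kernel call, which is
// skipped entirely when N is a multiple of 8.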
for (ib = 0; ib + 8 <= M; ib += 8) {
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_sve(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_large_sve<8>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
}
}
// Handle the remaining M - ib (< 8) rows. The kernels are specialized on m
// (the template parameter) so that the compiler can inline
// transpose_kernel_mxn_small_sve / transpose_kernel_mxn_large_sve and unroll
// the loops whose iteration count depends on M - ib.
// Specializing on m helps more than specializing on n because more loops in
// those kernels have iteration counts that depend on m.
switch (M - ib) {
case 1:
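// Added note: a single leftover row reduces to a scalar copy of one source
// row into one destination column.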
for (int64_t j = 0; j < N; ++j) {
dst[ib + j * ld_dst] = src[ib * ld_src + j];
}
break;
case 2:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_small_sve<2>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_small_sve<2>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 3:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_small_sve<3>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_small_sve<3>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 4:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_4x4_sve(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_small_sve<4>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
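// Added note: with 5-7 leftover rows the 4-row small kernel no longer covers
// the block, so the cases below use the masked large kernel and step jb by
// 8 columns instead of 4.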
case 5:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_large_sve<5>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_large_sve<5>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 6:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_large_sve<6>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_large_sve<6>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 7:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_large_sve<7>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_large_sve<7>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
}
}
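// Illustrative usage (an editorial sketch, not part of the upstream file):
// transpose a row-major M x N float matrix into a row-major N x M buffer,
// assuming the leading dimensions equal the logical widths (no padding).
//
//   std::vector<float> src(M * N), dst(N * M);
//   fbgemm::internal::transpose_sve<float>(
//       M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);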
#endif // HAVE_SVE
} // namespace internal
} // namespace fbgemm