/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif

#include "./TransposeUtils.h"
#include "./TransposeUtilsAvx2.h"

namespace fbgemm {

namespace {

// 16 * 6 = 96 instructions
|
|
inline void transpose_kernel_16x16_avx512(
|
|
const float* src,
|
|
int64_t ld_src,
|
|
float* dst,
|
|
int64_t ld_dst) {
|
|
// load from src to registers
|
|
// a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
|
|
// b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
|
|
// c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
|
|
// d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
|
|
// e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
|
|
// f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
|
|
// g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
|
|
// h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
|
|
// i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
|
|
// j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
|
|
// k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
|
|
// l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
|
|
// m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
|
|
// n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
|
|
// o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
|
|
// p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
|
|
__m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
|
|
__m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
|
|
__m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
|
|
__m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
|
|
__m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
|
|
__m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
|
|
__m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
|
|
__m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
|
|
__m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
|
|
__m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
|
|
__m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
|
|
__m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
|
|
__m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
|
|
__m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
|
|
__m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
|
|
__m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
|
|
|
|
__m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
|
|
// unpacking and interleaving 32-bit elements
|
|
// a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
|
|
// a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
|
|
// c0 d0 c1 d1 ...
|
|
// c2 d2 c3 d3 ...
|
|
// e0 f0 e1 f1 ...
|
|
// e2 f2 e3 f3 ...
|
|
// g0 h0 g1 h1 ...
|
|
// g2 h2 g3 h3 ...
|
|
// i0 ...
|
|
// i2 ...
|
|
// k0 ...
|
|
// k2 ...
|
|
// m0 ...
|
|
// m2 ...
|
|
  // o0 ...
  // o2 ...
|
|
ta = _mm512_unpacklo_ps(a, b);
|
|
tb = _mm512_unpackhi_ps(a, b);
|
|
tc = _mm512_unpacklo_ps(c, d);
|
|
td = _mm512_unpackhi_ps(c, d);
|
|
te = _mm512_unpacklo_ps(e, f);
|
|
tf = _mm512_unpackhi_ps(e, f);
|
|
tg = _mm512_unpacklo_ps(g, h);
|
|
th = _mm512_unpackhi_ps(g, h);
|
|
ti = _mm512_unpacklo_ps(i, j);
|
|
tj = _mm512_unpackhi_ps(i, j);
|
|
tk = _mm512_unpacklo_ps(k, l);
|
|
tl = _mm512_unpackhi_ps(k, l);
|
|
tm = _mm512_unpacklo_ps(m, n);
|
|
tn = _mm512_unpackhi_ps(m, n);
|
|
to = _mm512_unpacklo_ps(o, p);
|
|
tq = _mm512_unpackhi_ps(o, p);
|
|
|
|
// unpacking and interleaving 64-bit elements
|
|
// a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
|
|
// a1 b1 c1 d1 ...
|
|
// a2 b2 c2 d2 ...
|
|
// a3 b3 c3 d3 ...
|
|
// e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
|
|
// e1 f1 g1 h1 ...
|
|
// e2 f2 g2 h2 ...
|
|
// e3 f3 g3 h3 ...
|
|
// i0 j0 k0 l0 ...
|
|
// i1 j1 k1 l1 ...
|
|
// i2 j2 k2 l2 ...
|
|
// i3 j3 k3 l3 ...
|
|
// m0 n0 o0 p0 ...
|
|
// m1 n1 o1 p1 ...
|
|
// m2 n2 o2 p2 ...
|
|
// m3 n3 o3 p3 ...
|
|
a = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
|
|
b = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
|
|
c = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
|
|
d = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
|
|
e = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
|
|
f = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
|
|
g = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
|
|
h = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
|
|
i = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
|
|
j = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
|
|
k = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
|
|
l = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
|
|
m = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
|
|
n = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
|
|
o = _mm512_castpd_ps(
|
|
_mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
|
|
p = _mm512_castpd_ps(
|
|
_mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
|
|
|
|
// shuffle 128-bits (composed of 4 32-bit elements)
|
|
// a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
|
|
// a1 b1 c1 d1 ...
|
|
// a2 b2 c2 d2 ...
|
|
// a3 b3 c3 d3 ...
|
|
// a4 b4 c4 d4 ...
|
|
// a5 b5 c5 d5 ...
|
|
// a6 b6 c6 d6 ...
|
|
// a7 b7 c7 d7 ...
|
|
// i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
|
|
// i1 j1 k1 l1 ...
|
|
// i2 j2 k2 l2 ...
|
|
// i3 j3 k3 l3 ...
|
|
// i4 j4 k4 l4 ...
|
|
// i5 j5 k5 l5 ...
|
|
// i6 j6 k6 l6 ...
|
|
// i7 j7 k7 l7 ...
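  // For _mm512_shuffle_f32x4, imm 0x88 picks 128-bit lanes {0, 2} of each
  // source (first source in the low half of the result, second source in the
  // high half), while 0xdd picks lanes {1, 3} of each source.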
|
|
ta = _mm512_shuffle_f32x4(a, e, 0x88);
|
|
tb = _mm512_shuffle_f32x4(b, f, 0x88);
|
|
tc = _mm512_shuffle_f32x4(c, g, 0x88);
|
|
td = _mm512_shuffle_f32x4(d, h, 0x88);
|
|
te = _mm512_shuffle_f32x4(a, e, 0xdd);
|
|
tf = _mm512_shuffle_f32x4(b, f, 0xdd);
|
|
tg = _mm512_shuffle_f32x4(c, g, 0xdd);
|
|
th = _mm512_shuffle_f32x4(d, h, 0xdd);
|
|
ti = _mm512_shuffle_f32x4(i, m, 0x88);
|
|
tj = _mm512_shuffle_f32x4(j, n, 0x88);
|
|
tk = _mm512_shuffle_f32x4(k, o, 0x88);
|
|
tl = _mm512_shuffle_f32x4(l, p, 0x88);
|
|
tm = _mm512_shuffle_f32x4(i, m, 0xdd);
|
|
tn = _mm512_shuffle_f32x4(j, n, 0xdd);
|
|
to = _mm512_shuffle_f32x4(k, o, 0xdd);
|
|
tq = _mm512_shuffle_f32x4(l, p, 0xdd);
|
|
|
|
// shuffle 128-bits (composed of 4 32-bit elements)
|
|
// a0 b0 c0 d0 ... o0
|
|
// a1 b1 c1 d1 ... o1
|
|
// a2 b2 c2 d2 ... o2
|
|
// a3 b3 c3 d3 ... o3
|
|
// a4 ...
|
|
// a5 ...
|
|
// a6 ...
|
|
// a7 ...
|
|
// a8 ...
|
|
// a9 ...
|
|
// a10 ...
|
|
// a11 ...
|
|
// a12 ...
|
|
// a13 ...
|
|
// a14 ...
|
|
// a15 b15 c15 d15 ... o15
|
|
a = _mm512_shuffle_f32x4(ta, ti, 0x88);
|
|
b = _mm512_shuffle_f32x4(tb, tj, 0x88);
|
|
c = _mm512_shuffle_f32x4(tc, tk, 0x88);
|
|
d = _mm512_shuffle_f32x4(td, tl, 0x88);
|
|
e = _mm512_shuffle_f32x4(te, tm, 0x88);
|
|
f = _mm512_shuffle_f32x4(tf, tn, 0x88);
|
|
g = _mm512_shuffle_f32x4(tg, to, 0x88);
|
|
h = _mm512_shuffle_f32x4(th, tq, 0x88);
|
|
i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
|
|
j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
|
|
k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
|
|
l = _mm512_shuffle_f32x4(td, tl, 0xdd);
|
|
m = _mm512_shuffle_f32x4(te, tm, 0xdd);
|
|
n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
|
|
o = _mm512_shuffle_f32x4(tg, to, 0xdd);
|
|
p = _mm512_shuffle_f32x4(th, tq, 0xdd);
|
|
|
|
// store from registers to dst
|
|
_mm512_storeu_ps(&dst[0 * ld_dst], a);
|
|
_mm512_storeu_ps(&dst[1 * ld_dst], b);
|
|
_mm512_storeu_ps(&dst[2 * ld_dst], c);
|
|
_mm512_storeu_ps(&dst[3 * ld_dst], d);
|
|
_mm512_storeu_ps(&dst[4 * ld_dst], e);
|
|
_mm512_storeu_ps(&dst[5 * ld_dst], f);
|
|
_mm512_storeu_ps(&dst[6 * ld_dst], g);
|
|
_mm512_storeu_ps(&dst[7 * ld_dst], h);
|
|
_mm512_storeu_ps(&dst[8 * ld_dst], i);
|
|
_mm512_storeu_ps(&dst[9 * ld_dst], j);
|
|
_mm512_storeu_ps(&dst[10 * ld_dst], k);
|
|
_mm512_storeu_ps(&dst[11 * ld_dst], l);
|
|
_mm512_storeu_ps(&dst[12 * ld_dst], m);
|
|
_mm512_storeu_ps(&dst[13 * ld_dst], n);
|
|
_mm512_storeu_ps(&dst[14 * ld_dst], o);
|
|
_mm512_storeu_ps(&dst[15 * ld_dst], p);
|
|
}
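
// A minimal usage sketch (an assumption about the caller, not code from this
// file): a row-major M x N float matrix would typically be tiled into 16x16
// blocks and this kernel invoked once per full block, e.g.
//
//   for (int64_t ib = 0; ib + 16 <= M; ib += 16)
//     for (int64_t jb = 0; jb + 16 <= N; jb += 16)
//       transpose_kernel_16x16_avx512(
//           &src[ib * ld_src + jb], ld_src, &dst[jb * ld_dst + ib], ld_dst);
//
// Partial edge blocks (m, n < 16) are handled by transpose_kernel_mxn_avx512
// below.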
|
|
|
|
// kernel for transposing mxn where m, n <= 16
|
|
// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions
|
|
template <int M>
|
|
void transpose_kernel_mxn_avx512(
|
|
int N,
|
|
const float* src,
|
|
int64_t ld_src,
|
|
float* dst,
|
|
int64_t ld_dst) {
|
|
// load from src to registers
|
|
__mmask16 src_mask = (1 << N) - 1;
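  // e.g., N == 5 gives src_mask == 0b11111, so each masked load below reads
  // only the first 5 floats of a row and zero-fills the remaining lanes.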
|
|
__m512 input[16];
|
|
int i;
|
|
for (i = 0; i < M; ++i) {
|
|
input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]);
|
|
}
|
|
for (; i < 16; ++i) {
|
|
// Not really needed but to avoid uninitialized variable warning.
|
|
// Shouldn't be much overhead because xor can be executed in parallel with
|
|
// other instructions.
|
|
input[i] = _mm512_setzero_ps();
|
|
}
|
|
|
|
// unpacking and interleaving 32-bit elements
|
|
__m512 temp[16];
|
|
for (i = 0; i < (M + 1) / 2; ++i) {
|
|
temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]);
|
|
temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]);
|
|
}
|
|
for (i = i * 2; i < 16; ++i) {
|
|
temp[i] = _mm512_setzero_ps();
|
|
}
|
|
|
|
// unpacking and interleaving 64-bit elements
|
|
for (i = 0; i < (M + 3) / 4; ++i) {
|
|
input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd(
|
|
_mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
|
|
input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd(
|
|
_mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
|
|
input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd(
|
|
_mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
|
|
input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd(
|
|
_mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
|
|
}
|
|
|
|
// shuffle 128-bits (composed of 4 32-bit elements)
|
|
for (i = 0; i < (M + 7) / 8; ++i) {
|
|
temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88);
|
|
temp[8 * i + 1] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88);
|
|
temp[8 * i + 2] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88);
|
|
temp[8 * i + 3] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88);
|
|
temp[8 * i + 4] =
|
|
_mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd);
|
|
temp[8 * i + 5] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd);
|
|
temp[8 * i + 6] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd);
|
|
temp[8 * i + 7] =
|
|
_mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd);
|
|
}
|
|
|
|
// store from registers to dst
|
|
__mmask16 dst_mask = (1 << M) - 1;
|
|
for (i = 0; i < N; ++i) {
|
|
if (i < 8) {
|
|
input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88);
|
|
} else {
|
|
input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd);
|
|
}
|
|
_mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]);
|
|
}
|
|
}
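
// A hedged usage sketch (assumed caller, not code from this file): the M
// template parameter is the compile-time number of source rows and N is the
// runtime number of source columns, e.g.
//
//   transpose_kernel_mxn_avx512<3>(5, src, ld_src, dst, ld_dst);
//
// transposes a 3 x 5 block of src into a 5 x 3 block of dst.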
|
|
|
|
} // namespace
|
|
|
|
namespace internal {
|
|
|
|
template <typename T>
|
|
void transpose_avx512_contiguous_thin(
|
|
const int64_t M,
|
|
const int64_t N,
|
|
const T* src,
|
|
int64_t ld_src,
|
|
T* dst,
|
|
int64_t ld_dst);
|
|
|
|
template <typename T>
|
|
void transpose_avx512_contiguous_wide(
|
|
const int64_t M,
|
|
const int64_t N,
|
|
const T* src,
|
|
int64_t ld_src,
|
|
T* dst,
|
|
int64_t ld_dst);
|
|
|
|
// Permute elements within each 128-bit lane.
// e.g., if a 128-bit lane holds the following 16 bytes:
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
//
// After this function call, it becomes
// 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15
// The same happens in the other 3 lanes.
|
|
static inline __m512i permute_row(__m512i row) {
|
|
// clang-format off
|
|
__m256i shuffle_256v0 = _mm256_set_epi8(
|
|
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
|
|
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
|
|
// clang-format on
|
|
__m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v0);
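  // _mm512_shuffle_epi8 shuffles bytes within each 128-bit lane independently,
  // so the same 16-byte index pattern is replicated into all four lanes of the
  // control vector built below.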
|
|
row = _mm512_shuffle_epi8(
|
|
row, _mm512_inserti64x4(shuffle_512v, shuffle_256v0, 1));
|
|
return row;
|
|
}
|
|
|
|
static inline void core_transpose_16x32_block_i8(__m512i r[], __m512i u[]) {
|
|
// Result after this operation; Read in conjunction with comments in
|
|
// transpose_16x32_block
|
|
// 00_00 00_01 01_00 01_01 00_04 00_05 01_04 01_05 04_00 04_01 05_00 05_01
|
|
// 04_04 04_05 05_04 05_05
|
|
u[0] = _mm512_unpacklo_epi64(r[0], r[1]);
|
|
// 00_02 00_03 01_02 01_03 00_06 00_07 01_06 01_07 04_02 04_03 05_02 05_03
|
|
// 04_06 04_07 05_06 05_07
|
|
u[1] = _mm512_unpackhi_epi64(r[0], r[1]);
|
|
// 02_00 02_01 03_00 03_01 02_04 02_05 03_04 03_05 06_00 06_01 07_00 07_01
|
|
// 06_04 06_05 07_04 07_05
|
|
u[2] = _mm512_unpacklo_epi64(r[2], r[3]);
|
|
// 02_02 02_03 03_02 03_03 02_06 02_07 03_06 03_07 06_02 06_03 07_02 07_03
|
|
// 06_06 06_07 07_06 07_07
|
|
u[3] = _mm512_unpackhi_epi64(r[2], r[3]);
|
|
// 08_00 08_01 09_00 09_01 08_04 08_05 09_04 09_05 12_00 12_01 13_00 13_01
|
|
// 12_04 12_05 13_04 13_05
|
|
u[4] = _mm512_unpacklo_epi64(r[4], r[5]);
|
|
u[5] = _mm512_unpackhi_epi64(r[4], r[5]);
|
|
u[6] = _mm512_unpacklo_epi64(r[6], r[7]);
|
|
u[7] = _mm512_unpackhi_epi64(r[6], r[7]);
|
|
|
|
  // This instruction doesn't exist for epi32, so cast to ps and back.
|
|
// 00_00 01_00 02_00 03_00 00_04 01_04 02_04 03_04 04_00 05_00 06_00 07_00
|
|
// 04_04 05_04 06_04 07_04
|
|
r[0] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0x88));
|
|
// 00_01 01_01 02_01 03_01 00_05 01_05 02_05 03_05 04_01 05_01 06_01 07_01
|
|
// 04_05 05_05 06_05 07_05
|
|
r[1] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0xDD));
|
|
r[2] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0x88));
|
|
r[3] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0xDD));
|
|
// 08_00 09_00 10_00 11_00 08_04 09_04 10_04 11_04 12_00 13_00 14_00 15_00
|
|
// 12_04 13_04 14_04 15_04
|
|
r[4] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0x88));
|
|
r[5] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0xDD));
|
|
r[6] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0x88));
|
|
r[7] = _mm512_castps_si512(_mm512_shuffle_ps(
|
|
_mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0xDD));
|
|
|
|
// permute among 128-bit lanes
|
|
r[0] = permute_row(r[0]);
|
|
r[1] = permute_row(r[1]);
|
|
r[2] = permute_row(r[2]);
|
|
r[3] = permute_row(r[3]);
|
|
r[4] = permute_row(r[4]);
|
|
r[5] = permute_row(r[5]);
|
|
r[6] = permute_row(r[6]);
|
|
r[7] = permute_row(r[7]);
|
|
|
|
__m512i const1 = _mm512_set_epi32(
|
|
27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0);
|
|
__m512i const2 = _mm512_set_epi32(
|
|
31, 23, 15, 7, 30, 22, 14, 6, 29, 21, 13, 5, 28, 20, 12, 4);
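
  // For _mm512_permutex2var_epi32, index values 0-15 select 32-bit elements
  // from the first source register and 16-31 from the second (value & 15
  // gives the element position).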
|
|
|
|
// merge 128-bit values from two regs
|
|
u[0] = _mm512_permutex2var_epi32(r[0], const1, r[4]);
|
|
u[1] = _mm512_permutex2var_epi32(r[0], const2, r[4]);
|
|
u[2] = _mm512_permutex2var_epi32(r[1], const1, r[5]);
|
|
u[3] = _mm512_permutex2var_epi32(r[1], const2, r[5]);
|
|
u[4] = _mm512_permutex2var_epi32(r[2], const1, r[6]);
|
|
u[5] = _mm512_permutex2var_epi32(r[2], const2, r[6]);
|
|
u[6] = _mm512_permutex2var_epi32(r[3], const1, r[7]);
|
|
u[7] = _mm512_permutex2var_epi32(r[3], const2, r[7]);
|
|
}
|
|
|
|
static inline void core_transpose_16x16_block(__m512i r[], __m512i u[]) {
|
|
// a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9
|
|
// e10e11 f10f11
|
|
u[0] = _mm512_unpacklo_epi32(r[0], r[1]);
|
|
// a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7
|
|
// e12e13 f12f13 e14e15 f14f15
|
|
u[1] = _mm512_unpackhi_epi32(r[0], r[1]);
|
|
// c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9
|
|
// g10g11 h10h11
|
|
u[2] = _mm512_unpacklo_epi32(r[2], r[3]);
|
|
  // c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7
|
|
// g12g13 h12h13 g14g15 h14h15
|
|
u[3] = _mm512_unpackhi_epi32(r[2], r[3]);
|
|
// i j m n
|
|
u[4] = _mm512_unpacklo_epi32(r[4], r[5]);
|
|
u[5] = _mm512_unpackhi_epi32(r[4], r[5]);
|
|
// k l o p
|
|
u[6] = _mm512_unpacklo_epi32(r[6], r[7]);
|
|
u[7] = _mm512_unpackhi_epi32(r[6], r[7]);
|
|
|
|
// a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9
|
|
// h8h9
|
|
r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
|
|
// a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11
|
|
// f10f11 g10g11 h10h11
|
|
r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
|
|
  // a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
|
|
r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
|
|
  // a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
|
|
r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
|
|
// i j k l m n o p
|
|
r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
|
|
r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
|
|
r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
|
|
r[7] = _mm512_unpackhi_epi64(u[5], u[7]);
|
|
|
|
__m512i const1 = _mm512_set_epi32(
|
|
0x00370035,
|
|
0x00330031,
|
|
0x00270025,
|
|
0x00230021,
|
|
0x00170015,
|
|
0x00130011,
|
|
0x00070005,
|
|
0x00030001,
|
|
0x00360034,
|
|
0x00320030,
|
|
0x00260024,
|
|
0x00220020,
|
|
0x00160014,
|
|
0x00120010,
|
|
0x00060004,
|
|
0x00020000);
|
|
__m512i const2 = _mm512_set_epi32(
|
|
0x003f003d,
|
|
0x003b0039,
|
|
0x002f002d,
|
|
0x002b0029,
|
|
0x001f001d,
|
|
0x001b0019,
|
|
0x000f000d,
|
|
0x000b0009,
|
|
0x003e003c,
|
|
0x003a0038,
|
|
0x002e002c,
|
|
0x002a0028,
|
|
0x001e001c,
|
|
0x001a0018,
|
|
0x000e000c,
|
|
0x000a0008);
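
  // Each 32-bit literal above packs two 16-bit indices for
  // _mm512_permutex2var_epi16; index values 0x00-0x1f select words from the
  // first source register and 0x20-0x3f from the second.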
|
|
|
|
// merge values from two regs
|
|
u[0] = _mm512_permutex2var_epi16(r[0], const1, r[4]); // 0-- 1--
|
|
u[4] = _mm512_permutex2var_epi16(r[0], const2, r[4]); // 8-- 9--
|
|
u[2] = _mm512_permutex2var_epi16(r[2], const1, r[6]); // 4-- 5--
|
|
u[6] = _mm512_permutex2var_epi16(r[2], const2, r[6]); // 12-- 13--
|
|
u[1] = _mm512_permutex2var_epi16(r[1], const1, r[5]); // 2-- 3--
|
|
u[5] = _mm512_permutex2var_epi16(r[1], const2, r[5]); // 10-- 11--
|
|
u[3] = _mm512_permutex2var_epi16(r[3], const1, r[7]); // 6-- 7--
|
|
u[7] = _mm512_permutex2var_epi16(r[3], const2, r[7]); // 14-- 15--
|
|
}
|
|
|
|
static inline void load_with_remainders_i16(
|
|
const uint16_t* src,
|
|
int64_t ld_src,
|
|
__m512i r[],
|
|
int mrem,
|
|
int nrem) {
|
|
__m512i t[16];
|
|
if (nrem < 16) {
|
|
__mmask32 mask_nrem_v = (1ULL << nrem) - 1;
|
|
for (int i = 0; i < mrem; ++i) {
|
|
// mask load
|
|
t[i] = _mm512_maskz_loadu_epi16(mask_nrem_v, src + i * ld_src);
|
|
}
|
|
} else {
|
|
for (int i = 0; i < mrem; ++i) {
|
|
// normal load
|
|
t[i] = _mm512_castsi256_si512(_mm256_loadu_si256(
|
|
reinterpret_cast<const __m256i*>(src + i * ld_src)));
|
|
}
|
|
}
|
|
r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01);
|
|
r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01);
|
|
r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01);
|
|
r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01);
|
|
r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01);
|
|
r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01);
|
|
r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01);
|
|
r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01);
|
|
}
|
|
|
|
static inline void load_with_remainders_i8(
|
|
const uint8_t* src,
|
|
int64_t ld_src,
|
|
__m512i r[],
|
|
int mrem,
|
|
int nrem) {
|
|
__m512i t[16];
|
|
if (nrem < 32) {
|
|
__mmask64 mask_nrem_v = (1ULL << nrem) - 1;
|
|
for (int i = 0; i < mrem; ++i) {
|
|
// mask load
|
|
t[i] = _mm512_maskz_loadu_epi8(mask_nrem_v, src + i * ld_src);
|
|
}
|
|
} else {
|
|
for (int i = 0; i < mrem; ++i) {
|
|
// normal load
|
|
t[i] = _mm512_castsi256_si512(_mm256_loadu_si256(
|
|
reinterpret_cast<const __m256i*>(src + i * ld_src)));
|
|
}
|
|
}
|
|
r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01);
|
|
r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01);
|
|
r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01);
|
|
r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01);
|
|
r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01);
|
|
r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01);
|
|
r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01);
|
|
r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01);
|
|
}
|
|
|
|
static inline void store_with_remainders_i16(
|
|
uint16_t* dst,
|
|
int64_t ld_dst,
|
|
__m512i u[],
|
|
int mrem,
|
|
int nrem) {
|
|
if (mrem < 16) {
|
|
__mmask32 mask_mrem_v = (1ULL << mrem) - 1;
|
|
int i = 0;
|
|
|
|
for (; i < nrem / 2 * 2; i += 2) {
|
|
// mask store
|
|
int reg_idx = i / 2;
|
|
_mm512_mask_storeu_epi16(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0)));
|
|
_mm512_mask_storeu_epi16(
|
|
dst + (i + 1) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x1)));
|
|
}
|
|
if (i < nrem) {
|
|
int reg_idx = i / 2;
|
|
_mm512_mask_storeu_epi16(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0)));
|
|
}
|
|
} else {
|
|
int i = 0;
|
|
for (; i < nrem / 2 * 2; i += 2) {
|
|
// normal store
|
|
int reg_idx = i / 2;
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[reg_idx], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + (i + 1) * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[reg_idx], 0x1));
|
|
}
|
|
if (i < nrem) {
|
|
int reg_idx = i / 2;
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[reg_idx], 0x0));
|
|
}
|
|
}
|
|
}
|
|
|
|
static inline void store_with_remainders_i8(
|
|
uint8_t* dst,
|
|
int64_t ld_dst,
|
|
__m512i u[],
|
|
int mrem,
|
|
int nrem) {
|
|
if (mrem < 16) {
|
|
__mmask64 mask_mrem_v = (1ULL << mrem) - 1;
|
|
int i = 0;
|
|
for (; i < nrem / 4 * 4; i += 4) {
|
|
// mask store
|
|
      // we need rows 0, 4, 8, 12 => registers 0, 2, 4, 6
      // and rows 16, 20, 24, 28 => registers 1, 3, 5, 7
|
|
// See stores for non-rem case
|
|
int reg_idx = i / 16 + 2 * ((i % 16) / 4);
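      // e.g., i == 20 maps to reg_idx == 20 / 16 + 2 * ((20 % 16) / 4) == 3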
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x0)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 1) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x1)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 2) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x2)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 3) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x3)));
|
|
}
|
|
int rem = nrem - i;
|
|
int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4);
|
|
switch (rem) {
|
|
case 1:
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
|
|
break;
|
|
case 2:
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 1) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)));
|
|
break;
|
|
case 3:
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 0) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 1) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)));
|
|
_mm512_mask_storeu_epi8(
|
|
dst + (i + 2) * ld_dst,
|
|
mask_mrem_v,
|
|
_mm512_castsi128_si512(
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2)));
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
} else {
|
|
int i = 0;
|
|
for (; i < nrem / 4 * 4; i += 4) {
|
|
// normal store
|
|
int reg_idx = i / 16 + 2 * ((i % 16) / 4);
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_idx], 0x0));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_idx], 0x1));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_idx], 0x2));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 3) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_idx], 0x3));
|
|
}
|
|
int rem = nrem - i;
|
|
int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4);
|
|
switch (rem) {
|
|
case 1:
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
|
|
break;
|
|
case 2:
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1));
|
|
break;
|
|
case 3:
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1));
|
|
_mm_storeu_si128(
|
|
reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst),
|
|
_mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2));
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_4x16_block(
|
|
const float* src,
|
|
float* dst,
|
|
int64_t ld_src,
|
|
int nrem = 16) {
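  // Transposes a 4 x nrem float block (rows ld_src apart) into an nrem x 4
  // block stored contiguously at dst, i.e. the destination's implicit leading
  // dimension is 4 (inferred from the contiguous stores below).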
|
|
__m512i r[4];
|
|
// load
|
|
if (nrem < 16) {
|
|
__mmask16 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r[0] = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
|
|
r[1] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
|
|
r[2] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 2 * ld_src);
|
|
r[3] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 3 * ld_src);
|
|
|
|
} else {
|
|
r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
|
r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
|
r[2] =
|
|
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
|
|
r[3] =
|
|
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
|
|
}
|
|
// transpose
|
|
// a0b0 a1b1 a4b4 a5b5 a8b8 a9b9 a12b12 a13b13
|
|
// a2b2 a3b3 a6b6 a7b7 a10b10 a11b11 a14b14 a15b15
|
|
// c0d0 c1d1 c4d4 c5d5 c8d8 c9d9 c12d12 c13d13
|
|
  // c2d2 c3d3 c6d6 c7d7 c10d10 c11d11 c14d14 c15d15
|
|
__m512i t0 = _mm512_unpacklo_epi32(r[0], r[1]);
|
|
__m512i t1 = _mm512_unpackhi_epi32(r[0], r[1]);
|
|
__m512i t2 = _mm512_unpacklo_epi32(r[2], r[3]);
|
|
__m512i t3 = _mm512_unpackhi_epi32(r[2], r[3]);
|
|
|
|
r[0] = _mm512_unpacklo_epi64(t0, t2);
|
|
r[1] = _mm512_unpackhi_epi64(t0, t2);
|
|
r[2] = _mm512_unpacklo_epi64(t1, t3);
|
|
r[3] = _mm512_unpackhi_epi64(t1, t3);
|
|
|
|
t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
|
|
t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
|
|
t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
|
|
t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
|
|
|
|
r[0] = _mm512_shuffle_i32x4(t0, t2, 0x88);
|
|
r[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd);
|
|
r[2] = _mm512_shuffle_i32x4(t1, t3, 0x88);
|
|
r[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd);
|
|
// store
|
|
int i = 0;
|
|
for (; (i + 1) * 16 <= nrem * 4; i++) {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 16), r[i]);
|
|
}
|
|
int erem = nrem * 4 - i * 16;
|
|
if (erem > 0) {
|
|
// mask store
|
|
__mmask16 mask_rem_v = (1ULL << erem) - 1;
|
|
_mm512_mask_storeu_epi32(dst + i * 16, mask_rem_v, r[i]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_4x32_block(
|
|
const uint16_t* src,
|
|
uint16_t* dst,
|
|
int64_t ld_src,
|
|
int nrem = 32) {
|
|
__m512i r[4], d[4];
|
|
// load
|
|
if (nrem < 32) {
|
|
__mmask32 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r[0] = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
|
|
r[1] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
|
|
r[2] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 2 * ld_src);
|
|
r[3] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 3 * ld_src);
|
|
} else {
|
|
r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
|
r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
|
r[2] =
|
|
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
|
|
r[3] =
|
|
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
|
|
}
|
|
// transpose
|
|
d[0] = _mm512_unpacklo_epi16(r[0], r[1]);
|
|
d[1] = _mm512_unpackhi_epi16(r[0], r[1]);
|
|
d[2] = _mm512_unpacklo_epi16(r[2], r[3]);
|
|
d[3] = _mm512_unpackhi_epi16(r[2], r[3]);
|
|
|
|
r[0] = _mm512_unpacklo_epi32(d[0], d[2]);
|
|
r[1] = _mm512_unpackhi_epi32(d[0], d[2]);
|
|
r[2] = _mm512_unpacklo_epi32(d[1], d[3]);
|
|
r[3] = _mm512_unpackhi_epi32(d[1], d[3]);
|
|
|
|
d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
|
|
d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
|
|
d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
|
|
d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
|
|
|
|
r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88);
|
|
r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd);
|
|
r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88);
|
|
r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd);
|
|
// store
|
|
int i = 0;
|
|
for (; (i + 1) * 32 <= nrem * 4; i++) {
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 32), r[i]);
|
|
}
|
|
int erem = nrem * 4 - i * 32;
|
|
if (erem > 0) {
|
|
// mask store
|
|
__mmask32 mask_rem_v = (1ULL << erem) - 1;
|
|
_mm512_mask_storeu_epi16(dst + i * 32, mask_rem_v, r[i]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_16x4_block(
|
|
const float* src,
|
|
float* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 16) {
|
|
__m512i r[4], d[4];
|
|
int i = 0;
|
|
for (; (i + 1) * 16 <= mrem * 4; i++) {
|
|
// normal load
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
|
|
}
|
|
if (i * 16 < mrem * 4) {
|
|
__mmask16 mask_mrem_v = (1ULL << (mrem * 4 - i * 16)) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
|
|
}
|
|
|
|
// transpose
|
|
__m512i index1 = _mm512_set_epi32(
|
|
0x0f,
|
|
0x0b,
|
|
0x07,
|
|
0x03,
|
|
0x0e,
|
|
0x0a,
|
|
0x06,
|
|
0x02,
|
|
0x0d,
|
|
0x09,
|
|
0x05,
|
|
0x01,
|
|
0x0c,
|
|
0x08,
|
|
0x04,
|
|
0x00);
|
|
d[0] = _mm512_permutexvar_epi32(index1, r[0]);
|
|
d[1] = _mm512_permutexvar_epi32(index1, r[1]);
|
|
d[2] = _mm512_permutexvar_epi32(index1, r[2]);
|
|
d[3] = _mm512_permutexvar_epi32(index1, r[3]);
|
|
|
|
r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
|
|
r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
|
|
r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
|
|
r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
|
|
|
|
d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
|
|
d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
|
|
d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
|
|
d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
|
|
|
|
if (mrem < 16) {
|
|
// mask store
|
|
__mmask16 mask_rem_v = (1ULL << mrem) - 1;
|
|
_mm512_mask_storeu_epi32(dst + 0 * ld_dst, mask_rem_v, d[0]);
|
|
_mm512_mask_storeu_epi32(dst + 1 * ld_dst, mask_rem_v, d[1]);
|
|
_mm512_mask_storeu_epi32(dst + 2 * ld_dst, mask_rem_v, d[2]);
|
|
_mm512_mask_storeu_epi32(dst + 3 * ld_dst, mask_rem_v, d[3]);
|
|
} else {
|
|
    // normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_16x2_block(
|
|
const float* src,
|
|
float* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 16) {
|
|
__m512i r[2], d[2];
|
|
// Zero out r[] to avoid `may be used uninitialized` compilation error
|
|
r[0] = _mm512_setzero_si512();
|
|
r[1] = _mm512_setzero_si512();
|
|
int i = 0;
|
|
for (; (i + 1) * 16 <= mrem * 2; i++) {
|
|
// normal load
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 16));
|
|
}
|
|
if (i * 16 < mrem * 2) {
|
|
__mmask16 mask_mrem_v = (1ULL << (mrem * 2 - i * 16)) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16);
|
|
}
|
|
// transpose
|
|
__m512i index1 = _mm512_set_epi32(
|
|
0x1e,
|
|
0x1c,
|
|
0x1a,
|
|
0x18,
|
|
0x16,
|
|
0x14,
|
|
0x12,
|
|
0x10,
|
|
0x0e,
|
|
0x0c,
|
|
0x0a,
|
|
0x08,
|
|
0x06,
|
|
0x04,
|
|
0x02,
|
|
0x00);
|
|
__m512i index2 = _mm512_set_epi32(
|
|
0x1f,
|
|
0x1d,
|
|
0x1b,
|
|
0x19,
|
|
0x17,
|
|
0x15,
|
|
0x13,
|
|
0x11,
|
|
0x0f,
|
|
0x0d,
|
|
0x0b,
|
|
0x09,
|
|
0x07,
|
|
0x05,
|
|
0x03,
|
|
0x01);
|
|
|
|
// a0--p0
|
|
// a1--p1
|
|
d[0] = _mm512_permutex2var_epi32(r[0], index1, r[1]);
|
|
d[1] = _mm512_permutex2var_epi32(r[0], index2, r[1]);
|
|
|
|
// store
|
|
if (mrem < 16) {
|
|
__mmask16 mask_rem_v = (1ULL << mrem) - 1;
|
|
// mask store
|
|
_mm512_mask_storeu_epi32(dst, mask_rem_v, d[0]);
|
|
_mm512_mask_storeu_epi32(dst + ld_dst, mask_rem_v, d[1]);
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(dst, d[0]);
|
|
_mm512_storeu_si512(dst + ld_dst, d[1]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_64x4_block(
|
|
const uint8_t* src,
|
|
uint8_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 64) {
|
|
__m512i r[4], d[4];
|
|
// normal load
|
|
int i = 0;
|
|
for (; (i + 1) * 64 <= mrem * 4; i++) {
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64));
|
|
}
|
|
int erem = mrem * 4 - i * 64;
|
|
if (erem > 0) {
|
|
__mmask64 mask_mrem_v = (1ULL << erem) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
|
|
}
|
|
|
|
// transpose
|
|
__m512i index = _mm512_set_epi32(
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400);
|
|
|
|
d[0] = _mm512_shuffle_epi8(r[0], index);
|
|
d[1] = _mm512_shuffle_epi8(r[1], index);
|
|
d[2] = _mm512_shuffle_epi8(r[2], index);
|
|
d[3] = _mm512_shuffle_epi8(r[3], index);
|
|
|
|
__m512i index2 =
|
|
_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
|
|
r[0] = _mm512_permutexvar_epi32(index2, d[0]);
|
|
r[1] = _mm512_permutexvar_epi32(index2, d[1]);
|
|
r[2] = _mm512_permutexvar_epi32(index2, d[2]);
|
|
r[3] = _mm512_permutexvar_epi32(index2, d[3]);
|
|
|
|
__m512i t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
|
|
__m512i t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
|
|
__m512i t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
|
|
__m512i t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
|
|
|
|
d[0] = _mm512_shuffle_i32x4(t0, t2, 0x88);
|
|
d[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd);
|
|
d[2] = _mm512_shuffle_i32x4(t1, t3, 0x88);
|
|
d[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd);
|
|
|
|
// store
|
|
if (mrem < 64) {
|
|
__mmask64 mask_rem_v = (1ULL << mrem) - 1;
|
|
// mask store
|
|
_mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
|
|
_mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
|
|
_mm512_mask_storeu_epi8(dst + 2 * ld_dst, mask_rem_v, d[2]);
|
|
_mm512_mask_storeu_epi8(dst + 3 * ld_dst, mask_rem_v, d[3]);
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_32x4_block(
|
|
const uint16_t* src,
|
|
uint16_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 32) {
|
|
__m512i r[4], d[4];
|
|
int i = 0;
|
|
for (; (i + 1) * 32 <= mrem * 4; i++) {
|
|
// normal load
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32));
|
|
}
|
|
if (i * 32 < mrem * 4) {
|
|
__mmask32 mask_mrem_v = (1ULL << (mrem * 4 - i * 32)) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
|
|
}
|
|
// transpose
|
|
__m512i index = _mm512_set_epi32(
|
|
0x001f001b,
|
|
0x00170013,
|
|
0x000f000b,
|
|
0x00070003,
|
|
0x001e001a,
|
|
0x00160012,
|
|
0x000e000a,
|
|
0x00060002,
|
|
0x001d0019,
|
|
0x00150011,
|
|
0x000d0009,
|
|
0x00050001,
|
|
0x001c0018,
|
|
0x00140010,
|
|
0x000c0008,
|
|
0x00040000);
|
|
|
|
d[0] = _mm512_permutexvar_epi16(index, r[0]);
|
|
d[1] = _mm512_permutexvar_epi16(index, r[1]);
|
|
d[2] = _mm512_permutexvar_epi16(index, r[2]);
|
|
d[3] = _mm512_permutexvar_epi16(index, r[3]);
|
|
|
|
r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
|
|
r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
|
|
r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
|
|
r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
|
|
|
|
d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
|
|
d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
|
|
d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
|
|
d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
|
|
|
|
if (mrem < 32) {
|
|
// mask store
|
|
__mmask32 mask_rem_v = (1ULL << mrem) - 1;
|
|
_mm512_mask_storeu_epi16(dst + 0 * ld_dst, mask_rem_v, d[0]);
|
|
_mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, d[1]);
|
|
_mm512_mask_storeu_epi16(dst + 2 * ld_dst, mask_rem_v, d[2]);
|
|
_mm512_mask_storeu_epi16(dst + 3 * ld_dst, mask_rem_v, d[3]);
|
|
} else {
|
|
    // normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_2x16_block(
|
|
const float* src,
|
|
float* dst,
|
|
int64_t ld_src,
|
|
int nrem = 16) {
|
|
__m512i r0, r1;
|
|
// load
|
|
if (nrem < 16) {
|
|
__mmask16 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r0 = _mm512_maskz_loadu_epi32(mask_mrem_v, src);
|
|
r1 = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src);
|
|
} else {
|
|
r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
|
r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
|
}
|
|
// transpose
|
|
__m512i index1 = _mm512_set_epi32(
|
|
0x0017,
|
|
0x0007,
|
|
0x0016,
|
|
0x0006,
|
|
0x0015,
|
|
0x0005,
|
|
0x0014,
|
|
0x0004,
|
|
0x0013,
|
|
0x0003,
|
|
0x0012,
|
|
0x0002,
|
|
0x0011,
|
|
0x0001,
|
|
0x0010,
|
|
0x0000);
|
|
__m512i index2 = _mm512_set_epi32(
|
|
0x001f,
|
|
0x000f,
|
|
0x001e,
|
|
0x000e,
|
|
0x001d,
|
|
0x000d,
|
|
0x001c,
|
|
0x000c,
|
|
0x001b,
|
|
0x000b,
|
|
0x001a,
|
|
0x000a,
|
|
0x0019,
|
|
0x0009,
|
|
0x0018,
|
|
0x0008);
|
|
// a0 b0 a1 b1 a2 b2 a3 b3 a4 b4 a5 b5 a6 b6 a7 b7
|
|
// a8 b8 a9 b9 a10 b10 a11 b11 a12 b12 a13 b13 a14 b14 a15 b15
|
|
__m512i u0 = _mm512_permutex2var_epi32(r0, index1, r1);
|
|
__m512i u1 = _mm512_permutex2var_epi32(r0, index2, r1);
|
|
// store
|
|
if (nrem < 16) {
|
|
// mask store
|
|
if (nrem < 8) {
|
|
__mmask16 mask_rem_v = (1ULL << (nrem * 2)) - 1;
|
|
_mm512_mask_storeu_epi32(dst, mask_rem_v, u0);
|
|
} else {
|
|
__mmask16 mask_rem_v = (1ULL << ((nrem - 8) * 2)) - 1;
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0);
|
|
_mm512_mask_storeu_epi32(dst + 16, mask_rem_v, u1);
|
|
}
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 16), u1);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_64x2_block(
|
|
const uint8_t* src,
|
|
uint8_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 64) {
|
|
__m512i r[2], d[2];
|
|
// normal load
|
|
int i = 0;
|
|
for (; (i + 1) * 64 <= mrem * 2; i++) {
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 64));
|
|
}
|
|
int erem = mrem * 2 - i * 64;
|
|
if (erem > 0) {
|
|
__mmask64 mask_mrem_v = (1ULL << erem) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64);
|
|
}
|
|
|
|
// transpose
|
|
__m512i index1 = _mm512_set_epi32(
|
|
0x0f0d0b09,
|
|
0x07050301,
|
|
0x0e0c0a08,
|
|
0x06040200,
|
|
0x0f0d0b09,
|
|
0x07050301,
|
|
0x0e0c0a08,
|
|
0x06040200,
|
|
0x0f0d0b09,
|
|
0x07050301,
|
|
0x0e0c0a08,
|
|
0x06040200,
|
|
0x0f0d0b09,
|
|
0x07050301,
|
|
0x0e0c0a08,
|
|
0x06040200);
|
|
r[0] = _mm512_shuffle_epi8(r[0], index1);
|
|
r[1] = _mm512_shuffle_epi8(r[1], index1);
|
|
|
|
__m512i index2 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
|
|
__m512i index3 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
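  // For _mm512_permutex2var_epi64, index values 0-7 pick 64-bit elements from
  // the first source register and 8-15 from the second.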
|
|
|
|
d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]);
|
|
d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]);
|
|
|
|
// store
|
|
if (mrem < 64) {
|
|
__mmask64 mask_rem_v = (1ULL << mrem) - 1;
|
|
// mask store
|
|
_mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]);
|
|
_mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]);
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d[0]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), d[1]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_4x64_block(
|
|
const uint8_t* src,
|
|
uint8_t* dst,
|
|
int64_t ld_src,
|
|
int nrem = 64) {
|
|
__m512i r[4], d[4];
|
|
// load
|
|
if (nrem < 64) {
|
|
__mmask64 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
|
|
r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
|
|
r[2] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 2 * ld_src);
|
|
r[3] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 3 * ld_src);
|
|
} else {
|
|
    r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
    r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
    r[2] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 2 * ld_src));
    r[3] =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 3 * ld_src));
|
|
}
|
|
// transpose
|
|
d[0] = _mm512_unpacklo_epi32(r[0], r[1]);
|
|
d[1] = _mm512_unpackhi_epi32(r[0], r[1]);
|
|
d[2] = _mm512_unpacklo_epi32(r[2], r[3]);
|
|
d[3] = _mm512_unpackhi_epi32(r[2], r[3]);
|
|
|
|
r[0] = _mm512_unpacklo_epi64(d[0], d[2]);
|
|
r[1] = _mm512_unpackhi_epi64(d[0], d[2]);
|
|
r[2] = _mm512_unpacklo_epi64(d[1], d[3]);
|
|
r[3] = _mm512_unpackhi_epi64(d[1], d[3]);
|
|
|
|
d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44);
|
|
d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee);
|
|
d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44);
|
|
d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee);
|
|
|
|
r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88);
|
|
r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd);
|
|
r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88);
|
|
r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd);
|
|
|
|
__m512i index = _mm512_set_epi32(
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400,
|
|
0x0f0b0703,
|
|
0x0e0a0602,
|
|
0x0d090501,
|
|
0x0c080400);
|
|
|
|
d[0] = _mm512_shuffle_epi8(r[0], index);
|
|
d[1] = _mm512_shuffle_epi8(r[1], index);
|
|
d[2] = _mm512_shuffle_epi8(r[2], index);
|
|
d[3] = _mm512_shuffle_epi8(r[3], index);
|
|
|
|
// store
|
|
int i = 0;
|
|
for (; (i + 1) * 64 <= nrem * 4; i++) {
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]);
|
|
}
|
|
int erem = nrem * 4 - i * 64;
|
|
if (erem > 0) {
|
|
__mmask64 mask_rem_v = (1ULL << erem) - 1;
|
|
_mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_2x64_block(
|
|
const uint8_t* src,
|
|
uint8_t* dst,
|
|
int64_t ld_src,
|
|
int nrem = 64) {
|
|
__m512i r[2];
|
|
__m512i d[2];
|
|
// load
|
|
if (nrem < 64) {
|
|
__mmask64 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src);
|
|
r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src);
|
|
} else {
|
|
r[0] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
|
r[1] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
|
}
|
|
// transpose
|
|
// _mm512_mask_blend_epi8(0xaaaaaaaaaaaaaaaa, r0, r1);
|
|
d[0] = _mm512_unpacklo_epi16(r[0], r[1]);
|
|
d[1] = _mm512_unpackhi_epi16(r[0], r[1]);
|
|
__m512i index1 = _mm512_set_epi32(
|
|
0x0f0d0e0c,
|
|
0x0b090a08,
|
|
0x07050604,
|
|
0x03010200,
|
|
0x0f0d0e0c,
|
|
0x0b090a08,
|
|
0x07050604,
|
|
0x03010200,
|
|
0x0f0d0e0c,
|
|
0x0b090a08,
|
|
0x07050604,
|
|
0x03010200,
|
|
0x0f0d0e0c,
|
|
0x0b090a08,
|
|
0x07050604,
|
|
0x03010200);
|
|
r[0] = _mm512_shuffle_epi8(d[0], index1);
|
|
r[1] = _mm512_shuffle_epi8(d[1], index1);
|
|
__m512i index2 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
|
|
__m512i index3 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
|
|
// a0b0 a1b1 ... a31b31
|
|
// a32b32 ... a63b63
|
|
d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]);
|
|
d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]);
|
|
|
|
int i = 0;
|
|
for (; (i + 1) * 64 <= nrem * 2; i++) {
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]);
|
|
}
|
|
int erem = nrem * 2 - i * 64;
|
|
if (erem > 0) {
|
|
__mmask64 mask_rem_v = (1ULL << erem) - 1;
|
|
_mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_2x32_block(
|
|
const uint16_t* src,
|
|
uint16_t* dst,
|
|
int64_t ld_src,
|
|
int nrem = 32) {
|
|
__m512i r0, r1;
|
|
__m512i d0, d1;
|
|
// load
|
|
if (nrem < 32) {
|
|
__mmask32 mask_mrem_v = (1ULL << nrem) - 1;
|
|
r0 = _mm512_maskz_loadu_epi16(mask_mrem_v, src);
|
|
r1 = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src);
|
|
} else {
|
|
r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
|
r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
|
}
|
|
// transpose
|
|
d0 = _mm512_unpacklo_epi16(r0, r1);
|
|
d1 = _mm512_unpackhi_epi16(r0, r1);
|
|
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
|
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
|
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
|
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
|
|
|
// store
|
|
if (nrem < 16) {
|
|
__mmask32 mask_rem_v = (1ULL << (nrem * 2)) - 1;
|
|
_mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
|
|
} else if (nrem == 16) {
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
|
|
  } else if (nrem < 32) {
    __mmask32 mask_rem_v = (1ULL << (nrem * 2 - 32)) - 1;
    // d0 holds the first 32 output elements; only the tail in d1 needs a mask
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
    _mm512_mask_storeu_epi16(dst + 32, mask_rem_v, d1);
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
|
|
}
|
|
}
|
|
|
|
static inline void transpose_contiguous_32x2_block(
|
|
const uint16_t* src,
|
|
uint16_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 32) {
|
|
__m512i r[2], d[2];
|
|
// load
|
|
int i = 0;
|
|
for (; (i + 1) * 32 <= mrem * 2; i++) {
|
|
r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i * 32));
|
|
}
|
|
int erem = mrem * 2 - i * 32;
|
|
if (erem > 0) {
|
|
__mmask32 mask_mrem_v = (1ULL << erem) - 1;
|
|
r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32);
|
|
}
|
|
// transpose
|
|
__m512i index = _mm512_set_epi32(
|
|
0x001f001d,
|
|
0x001b0019,
|
|
0x00170015,
|
|
0x00130011,
|
|
0x000f000d,
|
|
0x000b0009,
|
|
0x00070005,
|
|
0x00030001,
|
|
0x001e001c,
|
|
0x001a0018,
|
|
0x00160014,
|
|
0x00120010,
|
|
0x000e000c,
|
|
0x000a0008,
|
|
0x00060004,
|
|
0x00020000);
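
  // Each literal above packs two 16-bit lane indices: the low half of the
  // permuted register gathers the even words (column 0 of the 32x2 block) and
  // the high half gathers the odd words (column 1).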
|
|
d[0] = _mm512_permutexvar_epi16(index, r[0]);
|
|
d[1] = _mm512_permutexvar_epi16(index, r[1]);
|
|
r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
|
|
r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
|
|
|
|
// store
|
|
if (mrem < 32) {
|
|
__mmask32 mask_rem_v = (1ULL << mrem) - 1;
|
|
// mask store
|
|
_mm512_mask_storeu_epi16(dst, mask_rem_v, r[0]);
|
|
_mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, r[1]);
|
|
} else {
|
|
// normal store
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r[0]);
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), r[1]);
|
|
}
|
|
}
|
|
|
|
template <bool MREM = false, bool NREM = false>
|
|
void transpose_16x16_block(
|
|
const uint16_t* src,
|
|
int64_t ld_src,
|
|
uint16_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 16,
|
|
int nrem = 16) {
|
|
__m512i r[8];
|
|
if (MREM || NREM) {
|
|
load_with_remainders_i16(src, ld_src, r, mrem, nrem);
|
|
} else {
|
|
__m256i t00 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src));
|
|
__m256i t01 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src));
|
|
__m256i t02 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src));
|
|
__m256i t03 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src));
|
|
__m256i t04 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src));
|
|
__m256i t05 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src));
|
|
__m256i t06 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src));
|
|
__m256i t07 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src));
|
|
__m256i t08 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src));
|
|
__m256i t09 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src));
|
|
__m256i t10 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src));
|
|
__m256i t11 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src));
|
|
__m256i t12 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src));
|
|
__m256i t13 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src));
|
|
__m256i t14 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src));
|
|
__m256i t15 =
|
|
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src));
|
|
|
|
// a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15
|
|
// e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
|
|
r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01);
|
|
// b0-b15
|
|
// f0-f15
|
|
r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01);
|
|
// c0-c15
|
|
// g0-g15
|
|
r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01);
|
|
// d0-d15
|
|
    // h0-h15
|
|
r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01);
|
|
// i0-i15
|
|
// m0-m15
|
|
r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01);
|
|
// j0-j15
|
|
// n0-n15
|
|
r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01);
|
|
// k0-k15
|
|
// o0-o15
|
|
r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01);
|
|
// l0-l15
|
|
// p0-p15
|
|
r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01);
|
|
}
|
|
__m512i u[8];
|
|
core_transpose_16x16_block(r, u);
|
|
if (MREM || NREM) {
|
|
store_with_remainders_i16(dst, ld_dst, u, mrem, nrem);
|
|
} else {
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 0 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[0], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 1 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[0], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 2 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[1], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 3 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[1], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 4 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[2], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 5 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[2], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 6 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[3], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 7 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[3], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 8 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[4], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 9 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[4], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 10 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[5], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 11 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[5], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 12 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[6], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 13 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[6], 0x01));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 14 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[7], 0x0));
|
|
_mm256_storeu_si256(
|
|
reinterpret_cast<__m256i*>(dst + 15 * ld_dst),
|
|
_mm512_extracti32x8_epi32(u[7], 0x01));
|
|
}
|
|
}
|
|
|
|
template <bool MREM = false, bool NREM = false>
|
|
void transpose_16x32_block(
|
|
const uint8_t* src,
|
|
int64_t ld_src,
|
|
uint8_t* dst,
|
|
int64_t ld_dst,
|
|
int mrem = 16,
|
|
int nrem = 32) {
|
|
  // Treat the numbers in a row as 4-byte integers.
  // Thus 03_04 is the 4-byte element in row 03 and column 04.
|
|
//
|
|
// 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07
|
|
// 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07
|
|
// 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07
|
|
// 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07
|
|
// 04_00 04_01 04_02 04_03 04_04 04_05 04_06 04_07
|
|
// 05_00 05_01 05_02 05_03 05_04 05_05 05_06 05_07
|
|
// 06_00 06_01 06_02 06_03 06_04 06_05 06_06 06_07
|
|
// 07_00 07_01 07_02 07_03 07_04 07_05 07_06 07_07
|
|
// 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07
|
|
// 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07
|
|
// 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07
|
|
// 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07
|
|
// 12_00 12_01 12_02 12_03 12_04 12_05 12_06 12_07
|
|
// 13_00 13_01 13_02 13_03 13_04 13_05 13_06 13_07
|
|
// 14_00 14_01 14_02 14_03 14_04 14_05 14_06 14_07
|
|
// 15_00 15_01 15_02 15_03 15_04 15_05 15_06 15_07
|
|
|
|
  __m512i r[8];
  if (MREM || NREM) {
    load_with_remainders_i8(src, ld_src, r, mrem, nrem);
  } else {
    __m256i t00 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 0 * ld_src));
    __m256i t04 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * ld_src));
    // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 04_00 04_01 04_02 04_03
    // 04_04 04_05 04_06 04_07
    r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01);

    __m256i t01 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 1 * ld_src));
    __m256i t05 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 5 * ld_src));
    // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 05_00 05_01 05_02 05_03
    // 05_04 05_05 05_06 05_07
    r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01);

    __m256i t02 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 2 * ld_src));
    __m256i t06 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 6 * ld_src));
    // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 06_00 06_01 06_02 06_03
    // 06_04 06_05 06_06 06_07
    r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01);

    __m256i t03 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 3 * ld_src));
    __m256i t07 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 7 * ld_src));
    // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 07_00 07_01 07_02 07_03
    // 07_04 07_05 07_06 07_07
    r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01);

    __m256i t08 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * ld_src));
    __m256i t12 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 12 * ld_src));
    // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 12_00 12_01 12_02 12_03
    // 12_04 12_05 12_06 12_07
    r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01);

    __m256i t09 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 9 * ld_src));
    __m256i t13 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 13 * ld_src));
    // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 13_00 13_01 13_02 13_03
    // 13_04 13_05 13_06 13_07
    r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01);

    __m256i t10 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 10 * ld_src));
    __m256i t14 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 14 * ld_src));
    // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 14_00 14_01 14_02 14_03
    // 14_04 14_05 14_06 14_07
    r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01);

    __m256i t11 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 11 * ld_src));
    __m256i t15 =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 15 * ld_src));
    // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 15_00 15_01 15_02 15_03
    // 15_04 15_05 15_06 15_07
    r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01);
  }

  __m512i u[8];
  core_transpose_16x32_block_i8(r, u);

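  // After core_transpose_16x32_block_i8, each 128-bit lane of u[i] holds one
  // 16-byte output row. The lanes of u[0..7] map to dst rows
  // {0-3, 16-19, 4-7, 20-23, 8-11, 24-27, 12-15, 28-31}, which is why the
  // full-tile stores below interleave the lower and upper halves of the tile.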
  if (MREM || NREM) {
    store_with_remainders_i8(dst, ld_dst, u, mrem, nrem);
  } else {
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 0 * ld_dst),
        _mm512_extracti32x4_epi32(u[0], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 1 * ld_dst),
        _mm512_extracti32x4_epi32(u[0], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 2 * ld_dst),
        _mm512_extracti32x4_epi32(u[0], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 3 * ld_dst),
        _mm512_extracti32x4_epi32(u[0], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 16 * ld_dst),
        _mm512_extracti32x4_epi32(u[1], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 17 * ld_dst),
        _mm512_extracti32x4_epi32(u[1], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 18 * ld_dst),
        _mm512_extracti32x4_epi32(u[1], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 19 * ld_dst),
        _mm512_extracti32x4_epi32(u[1], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 4 * ld_dst),
        _mm512_extracti32x4_epi32(u[2], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 5 * ld_dst),
        _mm512_extracti32x4_epi32(u[2], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 6 * ld_dst),
        _mm512_extracti32x4_epi32(u[2], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 7 * ld_dst),
        _mm512_extracti32x4_epi32(u[2], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 20 * ld_dst),
        _mm512_extracti32x4_epi32(u[3], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 21 * ld_dst),
        _mm512_extracti32x4_epi32(u[3], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 22 * ld_dst),
        _mm512_extracti32x4_epi32(u[3], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 23 * ld_dst),
        _mm512_extracti32x4_epi32(u[3], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 8 * ld_dst),
        _mm512_extracti32x4_epi32(u[4], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 9 * ld_dst),
        _mm512_extracti32x4_epi32(u[4], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 10 * ld_dst),
        _mm512_extracti32x4_epi32(u[4], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 11 * ld_dst),
        _mm512_extracti32x4_epi32(u[4], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 24 * ld_dst),
        _mm512_extracti32x4_epi32(u[5], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 25 * ld_dst),
        _mm512_extracti32x4_epi32(u[5], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 26 * ld_dst),
        _mm512_extracti32x4_epi32(u[5], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 27 * ld_dst),
        _mm512_extracti32x4_epi32(u[5], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 12 * ld_dst),
        _mm512_extracti32x4_epi32(u[6], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 13 * ld_dst),
        _mm512_extracti32x4_epi32(u[6], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 14 * ld_dst),
        _mm512_extracti32x4_epi32(u[6], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 15 * ld_dst),
        _mm512_extracti32x4_epi32(u[6], 0x3));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 28 * ld_dst),
        _mm512_extracti32x4_epi32(u[7], 0x0));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 29 * ld_dst),
        _mm512_extracti32x4_epi32(u[7], 0x1));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 30 * ld_dst),
        _mm512_extracti32x4_epi32(u[7], 0x2));
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + 31 * ld_dst),
        _mm512_extracti32x4_epi32(u[7], 0x3));
  }
}

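// "Thin" contiguous case: the transpose_avx512_contiguous_thin specializations
// below are reached when N == ld_src and N is 2 or 4, so the M x N source is a
// single dense buffer. Each helper call transposes a strip of 16 (float),
// 32 (uint16_t), or 64 (uint8_t) source rows; the trailing call passes mrem
// for the leftover rows.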
template <>
void transpose_avx512_contiguous_thin(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  if (N == 2) {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  } else if (N == 4) {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_thin(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  if (N == 2) {
    int64_t i = 0;
    for (; i < M / 32 * 32; i += 32) {
      transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  } else if (N == 4) {
    int64_t i = 0;
    for (; i < M / 32 * 32; i += 32) {
      transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_thin(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  if (N == 2) {
    int64_t i = 0;
    for (; i < M / 64 * 64; i += 64) {
      transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  } else if (N == 4) {
    int64_t i = 0;
    for (; i < M / 64 * 64; i += 64) {
      transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst);
    }
    int mrem = M - i;
    if (mrem > 0) {
      transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst, mrem);
    }
  }
}

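// "Wide" contiguous case: the transpose_avx512_contiguous_wide specializations
// below are reached when M == ld_dst and M is 2 or 4, so the M x N destination
// is a single dense buffer. The loop walks N in strips of 16 (float),
// 32 (uint16_t), or 64 (uint8_t) columns; the trailing call passes nrem for
// the leftover columns.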
template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 16 * 16; i += 16) {
      transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 16 * 16; i += 16) {
      transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 32 * 32; i += 32) {
      transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 32 * 32; i += 32) {
      transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 64 * 64; i += 64) {
      transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 64 * 64; i += 64) {
      transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

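// Transposes an M x N row-major float matrix into dst, where ld_src/ld_dst are
// the leading dimensions (in elements) of src and dst. The first two branches
// route fully dense 2- or 4-wide inputs/outputs to the contiguous kernels
// above; everything else is tiled into 16x16 blocks, with the remainder kernel
// chosen by N % 16 as explained in the comments below. A minimal usage sketch,
// assuming the specialization is exposed as fbgemm::internal::transpose_avx512
// (as the closing namespace comments below suggest) and that the caller owns
// both buffers:
//
//   std::vector<float> a(M * ld_src), at(N * ld_dst);
//   fbgemm::internal::transpose_avx512(
//       M, N, a.data(), ld_src, at.data(), ld_dst);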
template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t ib = 0, jb = 0;
    if (N % 16 > 0 && N % 16 < 4) {
      // If the remainder has n < 4 columns, we use the SSE kernel for the
      // remainder because it requires 4 * (2 * 4 + 2 * N) = 32 + 8N
      // instructions instead of 4 * 16 + 2 * N = 64 + 2N instructions needed
      // in the masked AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 4) {
          transpose_kernel_mxn_sse<4>(
              N - jb,
              &src[i * ld_src + jb],
              ld_src,
              &dst[i + jb * ld_dst],
              ld_dst);
        }
      }
    } else if (N % 16 == 4) {
      // If the remainder has 4 columns, we use the SSE kernel for the
      // remainder because it requires 4 * 16 = 64 instructions instead of
      // 4 * 16 + 2 * 4 = 72 instructions needed in the masked AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 4) {
          transpose_kernel_4x4_sse(
              &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
        }
      }
    } else if (N % 16 == 8) {
      // If the remainder has 8 columns, we use the AVX2 kernel for the
      // remainder because it requires 2 * 40 = 80 instructions instead of
      // 4 * 16 + 2 * 8 = 80 instructions + looping overhead in the masked
      // AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 8) {
          transpose_kernel_8x8_avx2(
              &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
        }
      }
    } else {
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<16>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
      }
    }

    // Specialization for small M - ib cases so that the compiler can inline
    // transpose_kernel_mxn_avx512 and unroll the loops whose iteration count
    // depends on M - ib.
    // Specialization for m helps more than for n in transpose_kernel_mxn_avx512
    // because we have more loops in that function whose iteration count depends
    // on m.
    switch (M - ib) {
      case 1:
        for (int64_t j = 0; j < N; ++j) {
          dst[ib + j * ld_dst] = src[ib * ld_src + j];
        }
        break;
      case 2:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_mxn_sse<2>(
              4,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<2>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 3:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_mxn_sse<3>(
              4,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<3>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 4:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_4x4_sse(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<4>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 5:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_mxn_avx2<5>(
              8,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<5>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 6:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_mxn_avx2<6>(
              8,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<6>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 7:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<7>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<7>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 8:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_8x8_avx2(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<8>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 9:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<9>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<9>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 10:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<10>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<10>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 11:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<11>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<11>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 12:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<12>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<12>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 13:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<13>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<13>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 14:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<14>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<14>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
      case 15:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<15>(
              16,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<15>(
              N - jb,
              &src[ib * ld_src + jb],
              ld_src,
              &dst[ib + jb * ld_dst],
              ld_dst);
        }
        break;
    }
  }
}

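// uint16_t specialization: tiles the matrix into 16x16 blocks handled by
// transpose_16x16_block, switching to the MREM/NREM remainder paths for the
// leftover rows (mrem) and columns (nrem) at the bottom and right edges.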
template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      int64_t j = 0;
      for (; j < N / 16 * 16; j += 16) {
        transpose_16x16_block<false, false>(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
      }
      // handle j rem
      int nrem = N - j;
      if (nrem > 0) {
        transpose_16x16_block<false, true>(
            src + i * ld_src + j,
            ld_src,
            dst + j * ld_dst + i,
            ld_dst,
            16,
            nrem);
      }
    }
    // handle i rem
    int mrem = M - i;
    if (mrem > 0) {
      int64_t j = 0;
      for (; j < N / 16 * 16; j += 16) {
        transpose_16x16_block<true, false>(
            src + i * ld_src + j,
            ld_src,
            dst + j * ld_dst + i,
            ld_dst,
            mrem,
            16);
      }
      // handle j rem
      int nrem = N - j;
      transpose_16x16_block<true, true>(
          src + i * ld_src + j,
          ld_src,
          dst + j * ld_dst + i,
          ld_dst,
          mrem,
          nrem);
    }
  }
}

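// uint8_t specialization: same tiling strategy as the uint16_t case, but with
// 16x32 blocks handled by transpose_16x32_block; mrem/nrem cover the partial
// block at the bottom-right of the matrix.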
template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      int64_t j = 0;
      for (; j < N / 32 * 32; j += 32) {
        transpose_16x32_block<false, false>(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
      }
      // handle j rem
      int nrem = N - j;
      if (nrem > 0) {
        transpose_16x32_block<false, true>(
            src + i * ld_src + j,
            ld_src,
            dst + j * ld_dst + i,
            ld_dst,
            16,
            nrem);
      }
    }

    // handle i rem
    int mrem = M - i;
    if (mrem > 0) {
      int64_t j = 0;
      for (; j < N / 32 * 32; j += 32) {
        transpose_16x32_block<true, false>(
            src + i * ld_src + j,
            ld_src,
            dst + j * ld_dst + i,
            ld_dst,
            mrem,
            32);
      }
      // handle j rem
      int nrem = N - j;
      transpose_16x32_block<true, true>(
          src + i * ld_src + j,
          ld_src,
          dst + j * ld_dst + i,
          ld_dst,
          mrem,
          nrem);
    }
  }
}

} // namespace internal

} // namespace fbgemm