/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif

#include "./TransposeUtils.h"
#include "./TransposeUtilsAvx2.h"

namespace fbgemm {

namespace {

// 16 * 6 = 96 instructions
inline void transpose_kernel_16x16_avx512(
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
  // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
  // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
  // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
  // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
  // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
  // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
  // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
  // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
  // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
  // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
  // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
  // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
  // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
  // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
  // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
  __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
  __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
  __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
  __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
  __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
  __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
  __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
  __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
  __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
  __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
  __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
  __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
  __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
  __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
  __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
  __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);

  __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;

  // unpacking and interleaving 32-bit elements
  // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
  // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
  // c0 d0 c1 d1 ...
  // c2 d2 c3 d3 ...
  // e0 f0 e1 f1 ...
  // e2 f2 e3 f3 ...
  // g0 h0 g1 h1 ...
  // g2 h2 g3 h3 ...
  // i0 ...
  // i2 ...
  // k0 ...
  // k2 ...
  // m0 ...
  // m2 ...
  // o0 ...
  // o2 ...
  ta = _mm512_unpacklo_ps(a, b);
  tb = _mm512_unpackhi_ps(a, b);
  tc = _mm512_unpacklo_ps(c, d);
  td = _mm512_unpackhi_ps(c, d);
  te = _mm512_unpacklo_ps(e, f);
  tf = _mm512_unpackhi_ps(e, f);
  tg = _mm512_unpacklo_ps(g, h);
  th = _mm512_unpackhi_ps(g, h);
  ti = _mm512_unpacklo_ps(i, j);
  tj = _mm512_unpackhi_ps(i, j);
  tk = _mm512_unpacklo_ps(k, l);
  tl = _mm512_unpackhi_ps(k, l);
  tm = _mm512_unpacklo_ps(m, n);
  tn = _mm512_unpackhi_ps(m, n);
  to = _mm512_unpacklo_ps(o, p);
  tq = _mm512_unpackhi_ps(o, p);

  // unpacking and interleaving 64-bit elements
  // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
  // a1 b1 c1 d1 ...
  // a2 b2 c2 d2 ...
  // a3 b3 c3 d3 ...
  // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
  // e1 f1 g1 h1 ...
  // e2 f2 g2 h2 ...
  // e3 f3 g3 h3 ...
  // i0 j0 k0 l0 ...
  // i1 j1 k1 l1 ...
  // i2 j2 k2 l2 ...
// i3 j3 k3 l3 ... // m0 n0 o0 p0 ... // m1 n1 o1 p1 ... // m2 n2 o2 p2 ... // m3 n3 o3 p3 ... a = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); b = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc))); c = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); d = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td))); e = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); f = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg))); g = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); h = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th))); i = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); j = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk))); k = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); l = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl))); m = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); n = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to))); o = _mm512_castpd_ps( _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); p = _mm512_castpd_ps( _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq))); // shuffle 128-bits (composed of 4 32-bit elements) // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 // a1 b1 c1 d1 ... // a2 b2 c2 d2 ... // a3 b3 c3 d3 ... // a4 b4 c4 d4 ... // a5 b5 c5 d5 ... // a6 b6 c6 d6 ... // a7 b7 c7 d7 ... // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8 // i1 j1 k1 l1 ... // i2 j2 k2 l2 ... // i3 j3 k3 l3 ... // i4 j4 k4 l4 ... // i5 j5 k5 l5 ... // i6 j6 k6 l6 ... // i7 j7 k7 l7 ... ta = _mm512_shuffle_f32x4(a, e, 0x88); tb = _mm512_shuffle_f32x4(b, f, 0x88); tc = _mm512_shuffle_f32x4(c, g, 0x88); td = _mm512_shuffle_f32x4(d, h, 0x88); te = _mm512_shuffle_f32x4(a, e, 0xdd); tf = _mm512_shuffle_f32x4(b, f, 0xdd); tg = _mm512_shuffle_f32x4(c, g, 0xdd); th = _mm512_shuffle_f32x4(d, h, 0xdd); ti = _mm512_shuffle_f32x4(i, m, 0x88); tj = _mm512_shuffle_f32x4(j, n, 0x88); tk = _mm512_shuffle_f32x4(k, o, 0x88); tl = _mm512_shuffle_f32x4(l, p, 0x88); tm = _mm512_shuffle_f32x4(i, m, 0xdd); tn = _mm512_shuffle_f32x4(j, n, 0xdd); to = _mm512_shuffle_f32x4(k, o, 0xdd); tq = _mm512_shuffle_f32x4(l, p, 0xdd); // shuffle 128-bits (composed of 4 32-bit elements) // a0 b0 c0 d0 ... o0 // a1 b1 c1 d1 ... o1 // a2 b2 c2 d2 ... o2 // a3 b3 c3 d3 ... o3 // a4 ... // a5 ... // a6 ... // a7 ... // a8 ... // a9 ... // a10 ... // a11 ... // a12 ... // a13 ... // a14 ... // a15 b15 c15 d15 ... 
  // o15
  a = _mm512_shuffle_f32x4(ta, ti, 0x88);
  b = _mm512_shuffle_f32x4(tb, tj, 0x88);
  c = _mm512_shuffle_f32x4(tc, tk, 0x88);
  d = _mm512_shuffle_f32x4(td, tl, 0x88);
  e = _mm512_shuffle_f32x4(te, tm, 0x88);
  f = _mm512_shuffle_f32x4(tf, tn, 0x88);
  g = _mm512_shuffle_f32x4(tg, to, 0x88);
  h = _mm512_shuffle_f32x4(th, tq, 0x88);
  i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
  j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
  k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
  l = _mm512_shuffle_f32x4(td, tl, 0xdd);
  m = _mm512_shuffle_f32x4(te, tm, 0xdd);
  n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
  o = _mm512_shuffle_f32x4(tg, to, 0xdd);
  p = _mm512_shuffle_f32x4(th, tq, 0xdd);

  // store from registers to dst
  _mm512_storeu_ps(&dst[0 * ld_dst], a);
  _mm512_storeu_ps(&dst[1 * ld_dst], b);
  _mm512_storeu_ps(&dst[2 * ld_dst], c);
  _mm512_storeu_ps(&dst[3 * ld_dst], d);
  _mm512_storeu_ps(&dst[4 * ld_dst], e);
  _mm512_storeu_ps(&dst[5 * ld_dst], f);
  _mm512_storeu_ps(&dst[6 * ld_dst], g);
  _mm512_storeu_ps(&dst[7 * ld_dst], h);
  _mm512_storeu_ps(&dst[8 * ld_dst], i);
  _mm512_storeu_ps(&dst[9 * ld_dst], j);
  _mm512_storeu_ps(&dst[10 * ld_dst], k);
  _mm512_storeu_ps(&dst[11 * ld_dst], l);
  _mm512_storeu_ps(&dst[12 * ld_dst], m);
  _mm512_storeu_ps(&dst[13 * ld_dst], n);
  _mm512_storeu_ps(&dst[14 * ld_dst], o);
  _mm512_storeu_ps(&dst[15 * ld_dst], p);
}

// kernel for transposing mxn where m, n <= 16
// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions
template <int M>
void transpose_kernel_mxn_avx512(
    int N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  __mmask16 src_mask = (1 << N) - 1;
  __m512 input[16];
  int i;
  for (i = 0; i < M; ++i) {
    input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]);
  }
  for (; i < 16; ++i) {
    // Not really needed but to avoid uninitialized variable warning.
    // Shouldn't be much overhead because xor can be executed in parallel with
    // other instructions.
    input[i] = _mm512_setzero_ps();
  }

  // unpacking and interleaving 32-bit elements
  __m512 temp[16];
  for (i = 0; i < (M + 1) / 2; ++i) {
    temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]);
    temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]);
  }
  for (i = i * 2; i < 16; ++i) {
    temp[i] = _mm512_setzero_ps();
  }

  // unpacking and interleaving 64-bit elements
  for (i = 0; i < (M + 3) / 4; ++i) {
    input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd(
        _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
    input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd(
        _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
    input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd(
        _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
    input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd(
        _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
  }

  // shuffle 128-bits (composed of 4 32-bit elements)
  for (i = 0; i < (M + 7) / 8; ++i) {
    temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88);
    temp[8 * i + 1] =
        _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88);
    temp[8 * i + 2] =
        _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88);
    temp[8 * i + 3] =
        _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88);
    temp[8 * i + 4] =
        _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd);
    temp[8 * i + 5] =
        _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd);
    temp[8 * i + 6] =
        _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd);
    temp[8 * i + 7] =
        _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd);
  }

  // store from registers to dst
  __mmask16 dst_mask = (1 << M) - 1;
  for (i = 0; i < N; ++i) {
    if (i < 8) {
      input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88);
    } else {
      input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd);
    }
    _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]);
  }
}

} // namespace

namespace internal {

template <typename T>
void transpose_avx512_contiguous_thin(
    const int64_t M,
    const int64_t N,
    const T* src,
    int64_t ld_src,
    T* dst,
    int64_t ld_dst);

template <typename T>
void transpose_avx512_contiguous_wide(
    const int64_t M,
    const int64_t N,
    const T* src,
    int64_t ld_src,
    T* dst,
    int64_t ld_dst);

// Permute elements in 128 bit lane
// e.g., if a 128-bit lane has the following elements:
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
//
// After this function call, it becomes
// 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15
// The same happens with the other 3 lanes.
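//
// For reference only, a scalar sketch of the same within-lane reordering
// (an illustrative documentation snippet, not part of the build; the names
// permute_lane_scalar / lane_in / lane_out are hypothetical):
//
//   void permute_lane_scalar(const uint8_t lane_in[16], uint8_t lane_out[16]) {
//     for (int j = 0; j < 16; ++j) {
//       // output byte j takes the input byte at row (j % 4), column (j / 4),
//       // i.e. the lane, viewed as a row-major 4x4 byte grid, is transposed
//       lane_out[j] = lane_in[(j % 4) * 4 + j / 4];
//     }
//   }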
static inline __m512i permute_row(__m512i row) { // clang-format off __m256i shuffle_256v0 = _mm256_set_epi8( 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); // clang-format on __m512i shuffle_512v = _mm512_castsi256_si512(shuffle_256v0); row = _mm512_shuffle_epi8( row, _mm512_inserti64x4(shuffle_512v, shuffle_256v0, 1)); return row; } static inline void core_transpose_16x32_block_i8(__m512i r[], __m512i u[]) { // Result after this operation; Read in conjunction with comments in // transpose_16x32_block // 00_00 00_01 01_00 01_01 00_04 00_05 01_04 01_05 04_00 04_01 05_00 05_01 // 04_04 04_05 05_04 05_05 u[0] = _mm512_unpacklo_epi64(r[0], r[1]); // 00_02 00_03 01_02 01_03 00_06 00_07 01_06 01_07 04_02 04_03 05_02 05_03 // 04_06 04_07 05_06 05_07 u[1] = _mm512_unpackhi_epi64(r[0], r[1]); // 02_00 02_01 03_00 03_01 02_04 02_05 03_04 03_05 06_00 06_01 07_00 07_01 // 06_04 06_05 07_04 07_05 u[2] = _mm512_unpacklo_epi64(r[2], r[3]); // 02_02 02_03 03_02 03_03 02_06 02_07 03_06 03_07 06_02 06_03 07_02 07_03 // 06_06 06_07 07_06 07_07 u[3] = _mm512_unpackhi_epi64(r[2], r[3]); // 08_00 08_01 09_00 09_01 08_04 08_05 09_04 09_05 12_00 12_01 13_00 13_01 // 12_04 12_05 13_04 13_05 u[4] = _mm512_unpacklo_epi64(r[4], r[5]); u[5] = _mm512_unpackhi_epi64(r[4], r[5]); u[6] = _mm512_unpacklo_epi64(r[6], r[7]); u[7] = _mm512_unpackhi_epi64(r[6], r[7]); // This instruction doesn't exist for epi32 so casting to ps // 00_00 01_00 02_00 03_00 00_04 01_04 02_04 03_04 04_00 05_00 06_00 07_00 // 04_04 05_04 06_04 07_04 r[0] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0x88)); // 00_01 01_01 02_01 03_01 00_05 01_05 02_05 03_05 04_01 05_01 06_01 07_01 // 04_05 05_05 06_05 07_05 r[1] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[0]), _mm512_castsi512_ps(u[2]), 0xDD)); r[2] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0x88)); r[3] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[1]), _mm512_castsi512_ps(u[3]), 0xDD)); // 08_00 09_00 10_00 11_00 08_04 09_04 10_04 11_04 12_00 13_00 14_00 15_00 // 12_04 13_04 14_04 15_04 r[4] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0x88)); r[5] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[4]), _mm512_castsi512_ps(u[6]), 0xDD)); r[6] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0x88)); r[7] = _mm512_castps_si512(_mm512_shuffle_ps( _mm512_castsi512_ps(u[5]), _mm512_castsi512_ps(u[7]), 0xDD)); // permute among 128-bit lanes r[0] = permute_row(r[0]); r[1] = permute_row(r[1]); r[2] = permute_row(r[2]); r[3] = permute_row(r[3]); r[4] = permute_row(r[4]); r[5] = permute_row(r[5]); r[6] = permute_row(r[6]); r[7] = permute_row(r[7]); __m512i const1 = _mm512_set_epi32( 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0); __m512i const2 = _mm512_set_epi32( 31, 23, 15, 7, 30, 22, 14, 6, 29, 21, 13, 5, 28, 20, 12, 4); // merge 128-bit values from two regs u[0] = _mm512_permutex2var_epi32(r[0], const1, r[4]); u[1] = _mm512_permutex2var_epi32(r[0], const2, r[4]); u[2] = _mm512_permutex2var_epi32(r[1], const1, r[5]); u[3] = _mm512_permutex2var_epi32(r[1], const2, r[5]); u[4] = _mm512_permutex2var_epi32(r[2], const1, r[6]); u[5] = _mm512_permutex2var_epi32(r[2], const2, r[6]); u[6] = _mm512_permutex2var_epi32(r[3], const1, r[7]); u[7] = _mm512_permutex2var_epi32(r[3], const2, r[7]); } static 
inline void core_transpose_16x16_block(__m512i r[], __m512i u[]) { // a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 // e10e11 f10f11 u[0] = _mm512_unpacklo_epi32(r[0], r[1]); // a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7 // e12e13 f12f13 e14e15 f14f15 u[1] = _mm512_unpackhi_epi32(r[0], r[1]); // c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 // g10g11 h10h11 u[2] = _mm512_unpacklo_epi32(r[2], r[3]); // c4c5 d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 // g12g13 h12h13 g14g15 h14h15 u[3] = _mm512_unpackhi_epi32(r[2], r[3]); // i j m n u[4] = _mm512_unpacklo_epi32(r[4], r[5]); u[5] = _mm512_unpackhi_epi32(r[4], r[5]); // k l o p u[6] = _mm512_unpacklo_epi32(r[6], r[7]); u[7] = _mm512_unpackhi_epi32(r[6], r[7]); // a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 // h8h9 r[0] = _mm512_unpacklo_epi64(u[0], u[2]); // a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11 // f10f11 g10g11 h10h11 r[1] = _mm512_unpackhi_epi64(u[0], u[2]); // a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 c12c13 d12d13 r[2] = _mm512_unpacklo_epi64(u[1], u[3]); // a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15 r[3] = _mm512_unpackhi_epi64(u[1], u[3]); // i j k l m n o p r[4] = _mm512_unpacklo_epi64(u[4], u[6]); r[5] = _mm512_unpackhi_epi64(u[4], u[6]); r[6] = _mm512_unpacklo_epi64(u[5], u[7]); r[7] = _mm512_unpackhi_epi64(u[5], u[7]); __m512i const1 = _mm512_set_epi32( 0x00370035, 0x00330031, 0x00270025, 0x00230021, 0x00170015, 0x00130011, 0x00070005, 0x00030001, 0x00360034, 0x00320030, 0x00260024, 0x00220020, 0x00160014, 0x00120010, 0x00060004, 0x00020000); __m512i const2 = _mm512_set_epi32( 0x003f003d, 0x003b0039, 0x002f002d, 0x002b0029, 0x001f001d, 0x001b0019, 0x000f000d, 0x000b0009, 0x003e003c, 0x003a0038, 0x002e002c, 0x002a0028, 0x001e001c, 0x001a0018, 0x000e000c, 0x000a0008); // merge values from two regs u[0] = _mm512_permutex2var_epi16(r[0], const1, r[4]); // 0-- 1-- u[4] = _mm512_permutex2var_epi16(r[0], const2, r[4]); // 8-- 9-- u[2] = _mm512_permutex2var_epi16(r[2], const1, r[6]); // 4-- 5-- u[6] = _mm512_permutex2var_epi16(r[2], const2, r[6]); // 12-- 13-- u[1] = _mm512_permutex2var_epi16(r[1], const1, r[5]); // 2-- 3-- u[5] = _mm512_permutex2var_epi16(r[1], const2, r[5]); // 10-- 11-- u[3] = _mm512_permutex2var_epi16(r[3], const1, r[7]); // 6-- 7-- u[7] = _mm512_permutex2var_epi16(r[3], const2, r[7]); // 14-- 15-- } static inline void load_with_remainders_i16( const uint16_t* src, int64_t ld_src, __m512i r[], int mrem, int nrem) { __m512i t[16]; if (nrem < 16) { __mmask32 mask_nrem_v = (1ULL << nrem) - 1; for (int i = 0; i < mrem; ++i) { // mask load t[i] = _mm512_maskz_loadu_epi16(mask_nrem_v, src + i * ld_src); } } else { for (int i = 0; i < mrem; ++i) { // normal load t[i] = _mm512_castsi256_si512(_mm256_loadu_si256( reinterpret_cast(src + i * ld_src))); } } r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01); r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01); r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01); r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01); r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01); r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01); r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01); r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01); } static inline void load_with_remainders_i8( const uint8_t* src, 
int64_t ld_src, __m512i r[], int mrem, int nrem) { __m512i t[16]; if (nrem < 32) { __mmask64 mask_nrem_v = (1ULL << nrem) - 1; for (int i = 0; i < mrem; ++i) { // mask load t[i] = _mm512_maskz_loadu_epi8(mask_nrem_v, src + i * ld_src); } } else { for (int i = 0; i < mrem; ++i) { // normal load t[i] = _mm512_castsi256_si512(_mm256_loadu_si256( reinterpret_cast(src + i * ld_src))); } } r[0] = _mm512_inserti64x4(t[0], _mm512_castsi512_si256(t[4]), 0x01); r[1] = _mm512_inserti64x4(t[1], _mm512_castsi512_si256(t[5]), 0x01); r[2] = _mm512_inserti64x4(t[2], _mm512_castsi512_si256(t[6]), 0x01); r[3] = _mm512_inserti64x4(t[3], _mm512_castsi512_si256(t[7]), 0x01); r[4] = _mm512_inserti64x4(t[8], _mm512_castsi512_si256(t[12]), 0x01); r[5] = _mm512_inserti64x4(t[9], _mm512_castsi512_si256(t[13]), 0x01); r[6] = _mm512_inserti64x4(t[10], _mm512_castsi512_si256(t[14]), 0x01); r[7] = _mm512_inserti64x4(t[11], _mm512_castsi512_si256(t[15]), 0x01); } static inline void store_with_remainders_i16( uint16_t* dst, int64_t ld_dst, __m512i u[], int mrem, int nrem) { if (mrem < 16) { __mmask32 mask_mrem_v = (1ULL << mrem) - 1; int i = 0; for (; i < nrem / 2 * 2; i += 2) { // mask store int reg_idx = i / 2; _mm512_mask_storeu_epi16( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0))); _mm512_mask_storeu_epi16( dst + (i + 1) * ld_dst, mask_mrem_v, _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x1))); } if (i < nrem) { int reg_idx = i / 2; _mm512_mask_storeu_epi16( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi256_si512(_mm512_extracti32x8_epi32(u[reg_idx], 0x0))); } } else { int i = 0; for (; i < nrem / 2 * 2; i += 2) { // normal store int reg_idx = i / 2; _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x8_epi32(u[reg_idx], 0x0)); _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + (i + 1) * ld_dst), _mm512_extracti32x8_epi32(u[reg_idx], 0x1)); } if (i < nrem) { int reg_idx = i / 2; _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x8_epi32(u[reg_idx], 0x0)); } } } static inline void store_with_remainders_i8( uint8_t* dst, int64_t ld_dst, __m512i u[], int mrem, int nrem) { if (mrem < 16) { __mmask64 mask_mrem_v = (1ULL << mrem) - 1; int i = 0; for (; i < nrem / 4 * 4; i += 4) { // mask store // we need 0, 4, 8, 16 => 0, 2, 4, 6 // and 16, 20, 24, 28 => 1, 3, 5, 7 // See stores for non-rem case int reg_idx = i / 16 + 2 * ((i % 16) / 4); _mm512_mask_storeu_epi8( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x0))); _mm512_mask_storeu_epi8( dst + (i + 1) * ld_dst, mask_mrem_v, _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x1))); _mm512_mask_storeu_epi8( dst + (i + 2) * ld_dst, mask_mrem_v, _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x2))); _mm512_mask_storeu_epi8( dst + (i + 3) * ld_dst, mask_mrem_v, _mm512_castsi128_si512(_mm512_extracti32x4_epi32(u[reg_idx], 0x3))); } int rem = nrem - i; int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4); switch (rem) { case 1: _mm512_mask_storeu_epi8( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); break; case 2: _mm512_mask_storeu_epi8( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); _mm512_mask_storeu_epi8( dst + (i + 1) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1))); break; 
case 3: _mm512_mask_storeu_epi8( dst + (i + 0) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0))); _mm512_mask_storeu_epi8( dst + (i + 1) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1))); _mm512_mask_storeu_epi8( dst + (i + 2) * ld_dst, mask_mrem_v, _mm512_castsi128_si512( _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2))); break; default: break; } } else { int i = 0; for (; i < nrem / 4 * 4; i += 4) { // normal store int reg_idx = i / 16 + 2 * ((i % 16) / 4); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x4_epi32(u[reg_idx], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), _mm512_extracti32x4_epi32(u[reg_idx], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst), _mm512_extracti32x4_epi32(u[reg_idx], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 3) * ld_dst), _mm512_extracti32x4_epi32(u[reg_idx], 0x3)); } int rem = nrem - i; int reg_rem_idx = i / 16 + 2 * ((i % 16) / 4); switch (rem) { case 1: _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); break; case 2: _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)); break; case 3: _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 0) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 1) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + (i + 2) * ld_dst), _mm512_extracti32x4_epi32(u[reg_rem_idx], 0x2)); break; default: break; } } } static inline void transpose_contiguous_4x16_block( const float* src, float* dst, int64_t ld_src, int nrem = 16) { __m512i r[4]; // load if (nrem < 16) { __mmask16 mask_mrem_v = (1ULL << nrem) - 1; r[0] = _mm512_maskz_loadu_epi32(mask_mrem_v, src); r[1] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src); r[2] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 2 * ld_src); r[3] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + 3 * ld_src); } else { r[0] = _mm512_loadu_si512(reinterpret_cast(src)); r[1] = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); r[2] = _mm512_loadu_si512(reinterpret_cast(src + 2 * ld_src)); r[3] = _mm512_loadu_si512(reinterpret_cast(src + 3 * ld_src)); } // transpose // a0b0 a1b1 a4b4 a5b5 a8b8 a9b9 a12b12 a13b13 // a2b2 a3b3 a6b6 a7b7 a10b10 a11b11 a14b14 a15b15 // c0d0 c1d1 c4d4 c5d5 c8d8 c9d9 c12d12 c13d13 // c2d2 c3d3 c6d6 c7d7 c10b10 c11d11 c14d14 c15d15 __m512i t0 = _mm512_unpacklo_epi32(r[0], r[1]); __m512i t1 = _mm512_unpackhi_epi32(r[0], r[1]); __m512i t2 = _mm512_unpacklo_epi32(r[2], r[3]); __m512i t3 = _mm512_unpackhi_epi32(r[2], r[3]); r[0] = _mm512_unpacklo_epi64(t0, t2); r[1] = _mm512_unpackhi_epi64(t0, t2); r[2] = _mm512_unpacklo_epi64(t1, t3); r[3] = _mm512_unpackhi_epi64(t1, t3); t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44); t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee); t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44); t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee); r[0] = _mm512_shuffle_i32x4(t0, t2, 0x88); r[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd); r[2] = _mm512_shuffle_i32x4(t1, t3, 0x88); r[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd); // store int i = 0; for (; (i + 1) * 16 <= nrem * 4; i++) { // normal store 
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 16), r[i]); } int erem = nrem * 4 - i * 16; if (erem > 0) { // mask store __mmask16 mask_rem_v = (1ULL << erem) - 1; _mm512_mask_storeu_epi32(dst + i * 16, mask_rem_v, r[i]); } } static inline void transpose_contiguous_4x32_block( const uint16_t* src, uint16_t* dst, int64_t ld_src, int nrem = 32) { __m512i r[4], d[4]; // load if (nrem < 32) { __mmask32 mask_mrem_v = (1ULL << nrem) - 1; r[0] = _mm512_maskz_loadu_epi16(mask_mrem_v, src); r[1] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src); r[2] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 2 * ld_src); r[3] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + 3 * ld_src); } else { r[0] = _mm512_loadu_si512(reinterpret_cast(src)); r[1] = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); r[2] = _mm512_loadu_si512(reinterpret_cast(src + 2 * ld_src)); r[3] = _mm512_loadu_si512(reinterpret_cast(src + 3 * ld_src)); } // transpose d[0] = _mm512_unpacklo_epi16(r[0], r[1]); d[1] = _mm512_unpackhi_epi16(r[0], r[1]); d[2] = _mm512_unpacklo_epi16(r[2], r[3]); d[3] = _mm512_unpackhi_epi16(r[2], r[3]); r[0] = _mm512_unpacklo_epi32(d[0], d[2]); r[1] = _mm512_unpackhi_epi32(d[0], d[2]); r[2] = _mm512_unpacklo_epi32(d[1], d[3]); r[3] = _mm512_unpackhi_epi32(d[1], d[3]); d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44); d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee); d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44); d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee); r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88); r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd); r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88); r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd); // store int i = 0; for (; (i + 1) * 32 <= nrem * 4; i++) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 32), r[i]); } int erem = nrem * 4 - i * 32; if (erem > 0) { // mask store __mmask32 mask_rem_v = (1ULL << erem) - 1; _mm512_mask_storeu_epi16(dst + i * 32, mask_rem_v, r[i]); } } static inline void transpose_contiguous_16x4_block( const float* src, float* dst, int64_t ld_dst, int mrem = 16) { __m512i r[4], d[4]; int i = 0; for (; (i + 1) * 16 <= mrem * 4; i++) { // normal load r[i] = _mm512_loadu_si512(reinterpret_cast(src + i * 16)); } if (i * 16 < mrem * 4) { __mmask16 mask_mrem_v = (1ULL << (mrem * 4 - i * 16)) - 1; r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16); } // transpose __m512i index1 = _mm512_set_epi32( 0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00); d[0] = _mm512_permutexvar_epi32(index1, r[0]); d[1] = _mm512_permutexvar_epi32(index1, r[1]); d[2] = _mm512_permutexvar_epi32(index1, r[2]); d[3] = _mm512_permutexvar_epi32(index1, r[3]); r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); if (mrem < 16) { // mask store __mmask16 mask_rem_v = (1ULL << mrem) - 1; _mm512_mask_storeu_epi32(dst + 0 * ld_dst, mask_rem_v, d[0]); _mm512_mask_storeu_epi32(dst + 1 * ld_dst, mask_rem_v, d[1]); _mm512_mask_storeu_epi32(dst + 2 * ld_dst, mask_rem_v, d[2]); _mm512_mask_storeu_epi32(dst + 3 * ld_dst, mask_rem_v, d[3]); } else { // normal load _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), 
d[1]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); } } static inline void transpose_contiguous_16x2_block( const float* src, float* dst, int64_t ld_dst, int mrem = 16) { __m512i r[2], d[2]; // Zero out r[] to avoid `may be used uninitialized` compilation error r[0] = _mm512_setzero_si512(); r[1] = _mm512_setzero_si512(); int i = 0; for (; (i + 1) * 16 <= mrem * 2; i++) { // normal load r[i] = _mm512_loadu_si512(reinterpret_cast(src + i * 16)); } if (i * 16 < mrem * 2) { __mmask16 mask_mrem_v = (1ULL << (mrem * 2 - i * 16)) - 1; r[i] = _mm512_maskz_loadu_epi32(mask_mrem_v, src + i * 16); } // transpose __m512i index1 = _mm512_set_epi32( 0x1e, 0x1c, 0x1a, 0x18, 0x16, 0x14, 0x12, 0x10, 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00); __m512i index2 = _mm512_set_epi32( 0x1f, 0x1d, 0x1b, 0x19, 0x17, 0x15, 0x13, 0x11, 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01); // a0--p0 // a1--p1 d[0] = _mm512_permutex2var_epi32(r[0], index1, r[1]); d[1] = _mm512_permutex2var_epi32(r[0], index2, r[1]); // store if (mrem < 16) { __mmask16 mask_rem_v = (1ULL << mrem) - 1; // mask store _mm512_mask_storeu_epi32(dst, mask_rem_v, d[0]); _mm512_mask_storeu_epi32(dst + ld_dst, mask_rem_v, d[1]); } else { // normal store _mm512_storeu_si512(dst, d[0]); _mm512_storeu_si512(dst + ld_dst, d[1]); } } static inline void transpose_contiguous_64x4_block( const uint8_t* src, uint8_t* dst, int64_t ld_dst, int mrem = 64) { __m512i r[4], d[4]; // normal load int i = 0; for (; (i + 1) * 64 <= mrem * 4; i++) { r[i] = _mm512_loadu_si512(reinterpret_cast(src + i * 64)); } int erem = mrem * 4 - i * 64; if (erem > 0) { __mmask64 mask_mrem_v = (1ULL << erem) - 1; r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64); } // transpose __m512i index = _mm512_set_epi32( 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400); d[0] = _mm512_shuffle_epi8(r[0], index); d[1] = _mm512_shuffle_epi8(r[1], index); d[2] = _mm512_shuffle_epi8(r[2], index); d[3] = _mm512_shuffle_epi8(r[3], index); __m512i index2 = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); r[0] = _mm512_permutexvar_epi32(index2, d[0]); r[1] = _mm512_permutexvar_epi32(index2, d[1]); r[2] = _mm512_permutexvar_epi32(index2, d[2]); r[3] = _mm512_permutexvar_epi32(index2, d[3]); __m512i t0 = _mm512_shuffle_i32x4(r[0], r[1], 0x44); __m512i t1 = _mm512_shuffle_i32x4(r[0], r[1], 0xee); __m512i t2 = _mm512_shuffle_i32x4(r[2], r[3], 0x44); __m512i t3 = _mm512_shuffle_i32x4(r[2], r[3], 0xee); d[0] = _mm512_shuffle_i32x4(t0, t2, 0x88); d[1] = _mm512_shuffle_i32x4(t0, t2, 0xdd); d[2] = _mm512_shuffle_i32x4(t1, t3, 0x88); d[3] = _mm512_shuffle_i32x4(t1, t3, 0xdd); // store if (mrem < 64) { __mmask64 mask_rem_v = (1ULL << mrem) - 1; // mask store _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]); _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]); _mm512_mask_storeu_epi8(dst + 2 * ld_dst, mask_rem_v, d[2]); _mm512_mask_storeu_epi8(dst + 3 * ld_dst, mask_rem_v, d[3]); } else { // normal store _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); } } static inline void 
transpose_contiguous_32x4_block( const uint16_t* src, uint16_t* dst, int64_t ld_dst, int mrem = 32) { __m512i r[4], d[4]; int i = 0; for (; (i + 1) * 32 <= mrem * 4; i++) { // normal load r[i] = _mm512_loadu_si512(reinterpret_cast(src + i * 32)); } if (i * 32 < mrem * 4) { __mmask32 mask_mrem_v = (1ULL << (mrem * 4 - i * 32)) - 1; r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32); } // transpose __m512i index = _mm512_set_epi32( 0x001f001b, 0x00170013, 0x000f000b, 0x00070003, 0x001e001a, 0x00160012, 0x000e000a, 0x00060002, 0x001d0019, 0x00150011, 0x000d0009, 0x00050001, 0x001c0018, 0x00140010, 0x000c0008, 0x00040000); d[0] = _mm512_permutexvar_epi16(index, r[0]); d[1] = _mm512_permutexvar_epi16(index, r[1]); d[2] = _mm512_permutexvar_epi16(index, r[2]); d[3] = _mm512_permutexvar_epi16(index, r[3]); r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); if (mrem < 32) { // mask store __mmask32 mask_rem_v = (1ULL << mrem) - 1; _mm512_mask_storeu_epi16(dst + 0 * ld_dst, mask_rem_v, d[0]); _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, d[1]); _mm512_mask_storeu_epi16(dst + 2 * ld_dst, mask_rem_v, d[2]); _mm512_mask_storeu_epi16(dst + 3 * ld_dst, mask_rem_v, d[3]); } else { // normal load _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 0 * ld_dst), d[0]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 1 * ld_dst), d[1]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 2 * ld_dst), d[2]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 3 * ld_dst), d[3]); } } static inline void transpose_contiguous_2x16_block( const float* src, float* dst, int64_t ld_src, int nrem = 16) { __m512i r0, r1; // load if (nrem < 16) { __mmask16 mask_mrem_v = (1ULL << nrem) - 1; r0 = _mm512_maskz_loadu_epi32(mask_mrem_v, src); r1 = _mm512_maskz_loadu_epi32(mask_mrem_v, src + ld_src); } else { r0 = _mm512_loadu_si512(reinterpret_cast(src)); r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); } // transpose __m512i index1 = _mm512_set_epi32( 0x0017, 0x0007, 0x0016, 0x0006, 0x0015, 0x0005, 0x0014, 0x0004, 0x0013, 0x0003, 0x0012, 0x0002, 0x0011, 0x0001, 0x0010, 0x0000); __m512i index2 = _mm512_set_epi32( 0x001f, 0x000f, 0x001e, 0x000e, 0x001d, 0x000d, 0x001c, 0x000c, 0x001b, 0x000b, 0x001a, 0x000a, 0x0019, 0x0009, 0x0018, 0x0008); // a0 b0 a1 b1 a2 b2 a3 b3 a4 b4 a5 b5 a6 b6 a7 b7 // a8 b8 a9 b9 a10 b10 a11 b11 a12 b12 a13 b13 a14 b14 a15 b15 __m512i u0 = _mm512_permutex2var_epi32(r0, index1, r1); __m512i u1 = _mm512_permutex2var_epi32(r0, index2, r1); // store if (nrem < 16) { // mask store if (nrem < 8) { __mmask16 mask_rem_v = (1ULL << (nrem * 2)) - 1; _mm512_mask_storeu_epi32(dst, mask_rem_v, u0); } else { __mmask16 mask_rem_v = (1ULL << ((nrem - 8) * 2)) - 1; _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0); _mm512_mask_storeu_epi32(dst + 16, mask_rem_v, u1); } } else { // normal store _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), u0); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 16), u1); } } static inline void transpose_contiguous_64x2_block( const uint8_t* src, uint8_t* dst, int64_t ld_dst, int mrem = 64) { __m512i r[2], d[2]; // normal load int i = 0; for (; (i + 1) * 64 <= mrem * 2; i++) { r[i] = _mm512_loadu_si512(reinterpret_cast(src + 
i * 64)); } int erem = mrem * 2 - i * 64; if (erem > 0) { __mmask64 mask_mrem_v = (1ULL << erem) - 1; r[i] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + i * 64); } // transpose __m512i index1 = _mm512_set_epi32( 0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200, 0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200, 0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200, 0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200); r[0] = _mm512_shuffle_epi8(r[0], index1); r[1] = _mm512_shuffle_epi8(r[1], index1); __m512i index2 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); __m512i index3 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]); d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]); // store if (mrem < 64) { __mmask64 mask_rem_v = (1ULL << mrem) - 1; // mask store _mm512_mask_storeu_epi8(dst, mask_rem_v, d[0]); _mm512_mask_storeu_epi8(dst + ld_dst, mask_rem_v, d[1]); } else { // normal store _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d[0]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + ld_dst), d[1]); } } static inline void transpose_contiguous_4x64_block( const uint8_t* src, uint8_t* dst, int64_t ld_src, int nrem = 64) { __m512i r[4], d[4]; // load if (nrem < 64) { __mmask64 mask_mrem_v = (1ULL << nrem) - 1; r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src); r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src); r[2] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 2 * ld_src); r[3] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + 3 * ld_src); } else { r[0] = _mm512_loadu_si512(reinterpret_cast(src)); r[1] = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); r[2] = _mm512_loadu_si512(reinterpret_cast(src + 2 * ld_src)); r[3] = _mm512_loadu_si512(reinterpret_cast(src + 3 * ld_src)); } // transpose d[0] = _mm512_unpacklo_epi32(r[0], r[1]); d[1] = _mm512_unpackhi_epi32(r[0], r[1]); d[2] = _mm512_unpacklo_epi32(r[2], r[3]); d[3] = _mm512_unpackhi_epi32(r[2], r[3]); r[0] = _mm512_unpacklo_epi64(d[0], d[2]); r[1] = _mm512_unpackhi_epi64(d[0], d[2]); r[2] = _mm512_unpacklo_epi64(d[1], d[3]); r[3] = _mm512_unpackhi_epi64(d[1], d[3]); d[0] = _mm512_shuffle_i32x4(r[0], r[1], 0x44); d[1] = _mm512_shuffle_i32x4(r[0], r[1], 0xee); d[2] = _mm512_shuffle_i32x4(r[2], r[3], 0x44); d[3] = _mm512_shuffle_i32x4(r[2], r[3], 0xee); r[0] = _mm512_shuffle_i32x4(d[0], d[2], 0x88); r[1] = _mm512_shuffle_i32x4(d[0], d[2], 0xdd); r[2] = _mm512_shuffle_i32x4(d[1], d[3], 0x88); r[3] = _mm512_shuffle_i32x4(d[1], d[3], 0xdd); __m512i index = _mm512_set_epi32( 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400, 0x0f0b0703, 0x0e0a0602, 0x0d090501, 0x0c080400); d[0] = _mm512_shuffle_epi8(r[0], index); d[1] = _mm512_shuffle_epi8(r[1], index); d[2] = _mm512_shuffle_epi8(r[2], index); d[3] = _mm512_shuffle_epi8(r[3], index); // store int i = 0; for (; (i + 1) * 64 <= nrem * 4; i++) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]); } int erem = nrem * 4 - i * 64; if (erem > 0) { __mmask64 mask_rem_v = (1ULL << erem) - 1; _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]); } } static inline void transpose_contiguous_2x64_block( const uint8_t* src, uint8_t* dst, int64_t ld_src, int nrem = 64) { __m512i r[2]; __m512i d[2]; // load if (nrem < 64) { __mmask64 mask_mrem_v = (1ULL << nrem) - 1; r[0] = _mm512_maskz_loadu_epi8(mask_mrem_v, src); r[1] = _mm512_maskz_loadu_epi8(mask_mrem_v, src + ld_src); } else { r[0] = _mm512_loadu_si512(reinterpret_cast(src)); r[1] = 
_mm512_loadu_si512(reinterpret_cast(src + ld_src)); } // transpose // _mm512_mask_blend_epi8(0xaaaaaaaaaaaaaaaa, r0, r1); d[0] = _mm512_unpacklo_epi16(r[0], r[1]); d[1] = _mm512_unpackhi_epi16(r[0], r[1]); __m512i index1 = _mm512_set_epi32( 0x0f0d0e0c, 0x0b090a08, 0x07050604, 0x03010200, 0x0f0d0e0c, 0x0b090a08, 0x07050604, 0x03010200, 0x0f0d0e0c, 0x0b090a08, 0x07050604, 0x03010200, 0x0f0d0e0c, 0x0b090a08, 0x07050604, 0x03010200); r[0] = _mm512_shuffle_epi8(d[0], index1); r[1] = _mm512_shuffle_epi8(d[1], index1); __m512i index2 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); __m512i index3 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); // a0b0 a1b1 ... a31b31 // a32b32 ... a63b63 d[0] = _mm512_permutex2var_epi64(r[0], index2, r[1]); d[1] = _mm512_permutex2var_epi64(r[0], index3, r[1]); int i = 0; for (; (i + 1) * 64 <= nrem * 2; i++) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i * 64), d[i]); } int erem = nrem * 2 - i * 64; if (erem > 0) { __mmask64 mask_rem_v = (1ULL << erem) - 1; _mm512_mask_storeu_epi8(dst + i * 64, mask_rem_v, d[i]); } } static inline void transpose_contiguous_2x32_block( const uint16_t* src, uint16_t* dst, int64_t ld_src, int nrem = 32) { __m512i r0, r1; __m512i d0, d1; // load if (nrem < 32) { __mmask32 mask_mrem_v = (1ULL << nrem) - 1; r0 = _mm512_maskz_loadu_epi16(mask_mrem_v, src); r1 = _mm512_maskz_loadu_epi16(mask_mrem_v, src + ld_src); } else { r0 = _mm512_loadu_si512(reinterpret_cast(src)); r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); } // transpose d0 = _mm512_unpacklo_epi16(r0, r1); d1 = _mm512_unpackhi_epi16(r0, r1); r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); // store if (nrem < 16) { __mmask32 mask_rem_v = (1ULL << (nrem * 2)) - 1; _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); } else if (nrem == 16) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); } else if (nrem < 32) { __mmask32 mask_rem_v = (1ULL << (nrem * 2 - 32)) - 1; _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); _mm512_mask_storeu_epi16( reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1); } else { // normal store _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); } } static inline void transpose_contiguous_32x2_block( const uint16_t* src, uint16_t* dst, int64_t ld_dst, int mrem = 32) { __m512i r[2], d[2]; // load int i = 0; for (; (i + 1) * 32 <= mrem * 2; i++) { r[i] = _mm512_loadu_si512(reinterpret_cast(src + i * 32)); } int erem = mrem * 2 - i * 32; if (erem > 0) { __mmask32 mask_mrem_v = (1ULL << erem) - 1; r[i] = _mm512_maskz_loadu_epi16(mask_mrem_v, src + i * 32); } // transpose __m512i index = _mm512_set_epi32( 0x001f001d, 0x001b0019, 0x00170015, 0x00130011, 0x000f000d, 0x000b0009, 0x00070005, 0x00030001, 0x001e001c, 0x001a0018, 0x00160014, 0x00120010, 0x000e000c, 0x000a0008, 0x00060004, 0x00020000); d[0] = _mm512_permutexvar_epi16(index, r[0]); d[1] = _mm512_permutexvar_epi16(index, r[1]); r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); // store if (mrem < 32) { __mmask32 mask_rem_v = (1ULL << mrem) - 1; // mask store _mm512_mask_storeu_epi16(dst, mask_rem_v, r[0]); _mm512_mask_storeu_epi16(dst + ld_dst, mask_rem_v, r[1]); } else { // normal store _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r[0]); _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 
ld_dst), r[1]);
  }
}

template <bool MREM = false, bool NREM = false>
void transpose_16x16_block(
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst,
    int mrem = 16,
    int nrem = 16) {
  __m512i r[8];
  if (MREM || NREM) {
    load_with_remainders_i16(src, ld_src, r, mrem, nrem);
  } else {
    __m256i t00 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 0 * ld_src));
    __m256i t01 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 1 * ld_src));
    __m256i t02 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 2 * ld_src));
    __m256i t03 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 3 * ld_src));
    __m256i t04 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 4 * ld_src));
    __m256i t05 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 5 * ld_src));
    __m256i t06 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 6 * ld_src));
    __m256i t07 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 7 * ld_src));
    __m256i t08 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 8 * ld_src));
    __m256i t09 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 9 * ld_src));
    __m256i t10 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 10 * ld_src));
    __m256i t11 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 11 * ld_src));
    __m256i t12 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 12 * ld_src));
    __m256i t13 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 13 * ld_src));
    __m256i t14 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 14 * ld_src));
    __m256i t15 = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(src + 15 * ld_src));
    // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15
    // e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
    r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01);
    // b0-b15
    // f0-f15
    r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01);
    // c0-c15
    // g0-g15
    r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01);
    // d0-d15
    // h0-h15
    r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01);
    // i0-i15
    // m0-m15
    r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01);
    // j0-j15
    // n0-n15
    r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01);
    // k0-k15
    // o0-o15
    r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01);
    // l0-l15
    // p0-p15
    r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01);
  }
  __m512i u[8];
  core_transpose_16x16_block(r, u);
  if (MREM || NREM) {
    store_with_remainders_i16(dst, ld_dst, u, mrem, nrem);
  } else {
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 0 * ld_dst),
        _mm512_extracti32x8_epi32(u[0], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 1 * ld_dst),
        _mm512_extracti32x8_epi32(u[0], 0x01));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 2 * ld_dst),
        _mm512_extracti32x8_epi32(u[1], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 3 * ld_dst),
        _mm512_extracti32x8_epi32(u[1], 0x01));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 4 * ld_dst),
        _mm512_extracti32x8_epi32(u[2], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 5 * ld_dst),
        _mm512_extracti32x8_epi32(u[2], 0x01));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 6 * ld_dst),
        _mm512_extracti32x8_epi32(u[3], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 7 * ld_dst),
        _mm512_extracti32x8_epi32(u[3], 0x01));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 8 * ld_dst),
        _mm512_extracti32x8_epi32(u[4], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 9 * ld_dst),
        _mm512_extracti32x8_epi32(u[4], 0x01));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 10 * ld_dst),
        _mm512_extracti32x8_epi32(u[5], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + 11 * ld_dst),
_mm512_extracti32x8_epi32(u[5], 0x01)); _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + 12 * ld_dst), _mm512_extracti32x8_epi32(u[6], 0x0)); _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + 13 * ld_dst), _mm512_extracti32x8_epi32(u[6], 0x01)); _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + 14 * ld_dst), _mm512_extracti32x8_epi32(u[7], 0x0)); _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + 15 * ld_dst), _mm512_extracti32x8_epi32(u[7], 0x01)); } } template void transpose_16x32_block( const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst, int mrem = 16, int nrem = 32) { // Treat the numbers in a row as 4-Byte integers. // Thus 03_04 is is 4-byte element in 03 row and 04 column // // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 // 04_00 04_01 04_02 04_03 04_04 04_05 04_06 04_07 // 05_00 05_01 05_02 05_03 05_04 05_05 05_06 05_07 // 06_00 06_01 06_02 06_03 06_04 06_05 06_06 06_07 // 07_00 07_01 07_02 07_03 07_04 07_05 07_06 07_07 // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 // 12_00 12_01 12_02 12_03 12_04 12_05 12_06 12_07 // 13_00 13_01 13_02 13_03 13_04 13_05 13_06 13_07 // 14_00 14_01 14_02 14_03 14_04 14_05 14_06 14_07 // 15_00 15_01 15_02 15_03 15_04 15_05 15_06 15_07 __m512i r[8]; if (MREM || NREM) { load_with_remainders_i8(src, ld_src, r, mrem, nrem); } else { __m256i t00 = _mm256_loadu_si256(reinterpret_cast(src + 0 * ld_src)); __m256i t04 = _mm256_loadu_si256(reinterpret_cast(src + 4 * ld_src)); // 00_00 00_01 00_02 00_03 00_04 00_05 00_06 00_07 04_00 04_01 04_02 04_03 // 04_04 04_05 04_06 04_07 r[0] = _mm512_inserti64x4(_mm512_castsi256_si512(t00), t04, 0x01); __m256i t01 = _mm256_loadu_si256(reinterpret_cast(src + 1 * ld_src)); __m256i t05 = _mm256_loadu_si256(reinterpret_cast(src + 5 * ld_src)); // 01_00 01_01 01_02 01_03 01_04 01_05 01_06 01_07 05_00 05_01 05_02 05_03 // 05_04 05_05 05_06 05_07 r[1] = _mm512_inserti64x4(_mm512_castsi256_si512(t01), t05, 0x01); __m256i t02 = _mm256_loadu_si256(reinterpret_cast(src + 2 * ld_src)); __m256i t06 = _mm256_loadu_si256(reinterpret_cast(src + 6 * ld_src)); // 02_00 02_01 02_02 02_03 02_04 02_05 02_06 02_07 06_00 06_01 06_02 06_03 // 06_04 06_05 06_06 06_07 r[2] = _mm512_inserti64x4(_mm512_castsi256_si512(t02), t06, 0x01); __m256i t03 = _mm256_loadu_si256(reinterpret_cast(src + 3 * ld_src)); __m256i t07 = _mm256_loadu_si256(reinterpret_cast(src + 7 * ld_src)); // 03_00 03_01 03_02 03_03 03_04 03_05 03_06 03_07 07_00 07_01 07_02 07_03 // 07_04 07_05 07_06 07_07 r[3] = _mm512_inserti64x4(_mm512_castsi256_si512(t03), t07, 0x01); __m256i t08 = _mm256_loadu_si256(reinterpret_cast(src + 8 * ld_src)); __m256i t12 = _mm256_loadu_si256(reinterpret_cast(src + 12 * ld_src)); // 08_00 08_01 08_02 08_03 08_04 08_05 08_06 08_07 12_00 12_01 12_02 12_03 // 12_04 12_05 12_06 12_07 r[4] = _mm512_inserti64x4(_mm512_castsi256_si512(t08), t12, 0x01); __m256i t09 = _mm256_loadu_si256(reinterpret_cast(src + 9 * ld_src)); __m256i t13 = _mm256_loadu_si256(reinterpret_cast(src + 13 * ld_src)); // 09_00 09_01 09_02 09_03 09_04 09_05 09_06 09_07 13_00 13_01 13_02 13_03 // 13_04 13_05 13_06 13_07 r[5] = _mm512_inserti64x4(_mm512_castsi256_si512(t09), t13, 0x01); __m256i t10 = _mm256_loadu_si256(reinterpret_cast(src + 10 * 
ld_src)); __m256i t14 = _mm256_loadu_si256(reinterpret_cast(src + 14 * ld_src)); // 10_00 10_01 10_02 10_03 10_04 10_05 10_06 10_07 14_00 14_01 14_02 14_03 // 14_04 14_05 14_06 14_07 r[6] = _mm512_inserti64x4(_mm512_castsi256_si512(t10), t14, 0x01); __m256i t11 = _mm256_loadu_si256(reinterpret_cast(src + 11 * ld_src)); __m256i t15 = _mm256_loadu_si256(reinterpret_cast(src + 15 * ld_src)); // 11_00 11_01 11_02 11_03 11_04 11_05 11_06 11_07 15_00 15_01 15_02 15_03 // 15_04 15_05 15_06 15_07 r[7] = _mm512_inserti64x4(_mm512_castsi256_si512(t11), t15, 0x01); } __m512i u[8]; core_transpose_16x32_block_i8(r, u); if (MREM || NREM) { store_with_remainders_i8(dst, ld_dst, u, mrem, nrem); } else { _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 0 * ld_dst), _mm512_extracti32x4_epi32(u[0], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 1 * ld_dst), _mm512_extracti32x4_epi32(u[0], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 2 * ld_dst), _mm512_extracti32x4_epi32(u[0], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 3 * ld_dst), _mm512_extracti32x4_epi32(u[0], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 16 * ld_dst), _mm512_extracti32x4_epi32(u[1], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 17 * ld_dst), _mm512_extracti32x4_epi32(u[1], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 18 * ld_dst), _mm512_extracti32x4_epi32(u[1], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 19 * ld_dst), _mm512_extracti32x4_epi32(u[1], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 4 * ld_dst), _mm512_extracti32x4_epi32(u[2], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 5 * ld_dst), _mm512_extracti32x4_epi32(u[2], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 6 * ld_dst), _mm512_extracti32x4_epi32(u[2], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 7 * ld_dst), _mm512_extracti32x4_epi32(u[2], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 20 * ld_dst), _mm512_extracti32x4_epi32(u[3], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 21 * ld_dst), _mm512_extracti32x4_epi32(u[3], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 22 * ld_dst), _mm512_extracti32x4_epi32(u[3], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 23 * ld_dst), _mm512_extracti32x4_epi32(u[3], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 8 * ld_dst), _mm512_extracti32x4_epi32(u[4], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 9 * ld_dst), _mm512_extracti32x4_epi32(u[4], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 10 * ld_dst), _mm512_extracti32x4_epi32(u[4], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 11 * ld_dst), _mm512_extracti32x4_epi32(u[4], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 24 * ld_dst), _mm512_extracti32x4_epi32(u[5], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 25 * ld_dst), _mm512_extracti32x4_epi32(u[5], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 26 * ld_dst), _mm512_extracti32x4_epi32(u[5], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 27 * ld_dst), _mm512_extracti32x4_epi32(u[5], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 12 * ld_dst), _mm512_extracti32x4_epi32(u[6], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 13 * ld_dst), _mm512_extracti32x4_epi32(u[6], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 14 * ld_dst), _mm512_extracti32x4_epi32(u[6], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 15 * 
ld_dst), _mm512_extracti32x4_epi32(u[6], 0x3)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 28 * ld_dst), _mm512_extracti32x4_epi32(u[7], 0x0)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 29 * ld_dst), _mm512_extracti32x4_epi32(u[7], 0x1)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 30 * ld_dst), _mm512_extracti32x4_epi32(u[7], 0x2)); _mm_storeu_si128( reinterpret_cast<__m128i*>(dst + 31 * ld_dst), _mm512_extracti32x4_epi32(u[7], 0x3)); } } template <> void transpose_avx512_contiguous_thin( int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { if (N == 2) { int64_t i = 0; for (; i < M / 16 * 16; i += 16) { transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_16x2_block(src + i * ld_src, dst + i, ld_dst, mrem); } } else if (N == 4) { int64_t i = 0; for (; i < M / 16 * 16; i += 16) { transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_16x4_block(src + i * ld_src, dst + i, ld_dst, mrem); } } } template <> void transpose_avx512_contiguous_thin( int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { if (N == 2) { int64_t i = 0; for (; i < M / 32 * 32; i += 32) { transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_32x2_block(src + i * ld_src, dst + i, ld_dst, mrem); } } else if (N == 4) { int64_t i = 0; for (; i < M / 32 * 32; i += 32) { transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_32x4_block(src + i * ld_src, dst + i, ld_dst, mrem); } } } template <> void transpose_avx512_contiguous_thin( int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) { if (N == 2) { int64_t i = 0; for (; i < M / 64 * 64; i += 64) { transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_64x2_block(src + i * ld_src, dst + i, ld_dst, mrem); } } else if (N == 4) { int64_t i = 0; for (; i < M / 64 * 64; i += 64) { transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst); } int mrem = M - i; if (mrem > 0) { transpose_contiguous_64x4_block(src + i * ld_src, dst + i, ld_dst, mrem); } } } template <> void transpose_avx512_contiguous_wide( int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { if (M == 2) { int64_t i = 0; for (; i < N / 16 * 16; i += 16) { transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src); } int nrem = N - i; if (nrem > 0) { transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src, nrem); } } else if (M == 4) { int64_t i = 0; for (; i < N / 16 * 16; i += 16) { transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src); } int nrem = N - i; if (nrem > 0) { transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src, nrem); } } } template <> void transpose_avx512_contiguous_wide( int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { if (M == 2) { int64_t i = 0; for (; i < N / 32 * 32; i += 32) { transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src); } int nrem = N - i; if (nrem > 0) { transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src, nrem); } } else if (M == 4) { int64_t i = 0; for (; i < N / 32 * 32; i += 32) { transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src); } int 
template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 16 * 16; i += 16) {
      transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 16 * 16; i += 16) {
      transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x16_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 32 * 32; i += 32) {
      transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 32 * 32; i += 32) {
      transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x32_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

template <>
void transpose_avx512_contiguous_wide(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  if (M == 2) {
    int64_t i = 0;
    for (; i < N / 64 * 64; i += 64) {
      transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_2x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  } else if (M == 4) {
    int64_t i = 0;
    for (; i < N / 64 * 64; i += 64) {
      transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src);
    }
    int nrem = N - i;
    if (nrem > 0) {
      transpose_contiguous_4x64_block(src + i, dst + i * ld_dst, ld_src, nrem);
    }
  }
}

template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t ib = 0, jb = 0;
    if (N % 16 > 0 && N % 16 < 4) {
      // If the remainder has n < 4 columns, we use the SSE kernel for the
      // remainder because it requires 4 * (2 * 4 + 2 * N) = 32 + 8N
      // instructions instead of the 4 * 16 + 2 * N = 64 + 2N instructions
      // needed in the masked AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 4) {
          transpose_kernel_mxn_sse<4>(
              N - jb, &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
        }
      }
    } else if (N % 16 == 4) {
      // If the remainder has 4 columns, we use the SSE kernel for the
      // remainder because it requires 4 * 16 = 64 instructions instead of the
      // 4 * 16 + 2 * 4 = 72 instructions needed in the masked AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 4) {
          transpose_kernel_4x4_sse(
              &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
        }
      }
    } else if (N % 16 == 8) {
      // If the remainder has 8 columns, we use the AVX kernel for the
      // remainder because it requires 2 * 40 = 80 instructions instead of the
      // 4 * 16 + 2 * 8 = 80 instructions plus looping overhead in the masked
      // AVX512 kernel.
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        for (int64_t i = ib; i < ib + 16; i += 8) {
          transpose_kernel_8x8_avx2(
              &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
        }
      }
    } else {
      for (ib = 0; ib + 16 <= M; ib += 16) {
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_16x16_avx512(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<16>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
      }
    }
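
    // Worked numbers for the remainder-kernel choices above (a documentation
    // note added here; it only evaluates the instruction-count formulas from
    // the comments, with n denoting the remainder width N % 16):
    //   n = 2 (likewise 1, 3): SSE mxn -> 4 * (2 * 4 + 2 * 2) = 48
    //                          vs. masked AVX512 -> 64 + 2 * 2 = 68.
    //   n = 4:                 SSE 4x4 -> 4 * 16 = 64 vs. masked AVX512 72.
    //   n = 8:                 AVX2 8x8 -> 2 * 40 = 80 vs. masked AVX512 80
    //                          plus loop overhead.
    // Any other nonzero remainder falls through to the masked AVX512 kernel.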
    // Specialization for small M - ib cases so that the compiler can inline
    // transpose_kernel_mxn_avx512 and unroll the loops whose iteration count
    // depends on M - ib.
    // Specialization for m helps more than for n in transpose_kernel_mxn_avx512
    // because we have more loops in that function whose iteration count
    // depends on m.
    switch (M - ib) {
      case 1:
        for (int64_t j = 0; j < N; ++j) {
          dst[ib + j * ld_dst] = src[ib * ld_src + j];
        }
        break;
      case 2:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_mxn_sse<2>(
              4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<2>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 3:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_mxn_sse<3>(
              4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<3>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 4:
        for (jb = 0; jb + 4 <= N; jb += 4) {
          transpose_kernel_4x4_sse(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_sse<4>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 5:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_mxn_avx2<5>(
              8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<5>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 6:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_mxn_avx2<6>(
              8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<6>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 7:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<7>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<7>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 8:
        for (jb = 0; jb + 8 <= N; jb += 8) {
          transpose_kernel_8x8_avx2(
              &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx2<8>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 9:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<9>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<9>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 10:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<10>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<10>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 11:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<11>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<11>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 12:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<12>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<12>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 13:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<13>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<13>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 14:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<14>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<14>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
      case 15:
        for (jb = 0; jb + 16 <= N; jb += 16) {
          transpose_kernel_mxn_avx512<15>(
              16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        if (jb < N) {
          transpose_kernel_mxn_avx512<15>(
              N - jb, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
        }
        break;
    }
  }
}
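
// Example usage (an illustrative, hedged sketch added for documentation; the
// matrix shape, fill values, and calling context are not taken from the
// original sources, and <vector> would be needed by the caller):
//
//   std::vector<float> src(7 * 5), dst(5 * 7);
//   for (size_t t = 0; t < src.size(); ++t) {
//     src[t] = static_cast<float>(t);
//   }
//   // Transpose a 7 x 5 row-major matrix (ld_src = 5) into a 5 x 7 one
//   // (ld_dst = 7). Afterwards dst[j * 7 + i] == src[i * 5 + j] holds for
//   // all 0 <= i < 7, 0 <= j < 5.
//   fbgemm::internal::transpose_avx512(
//       7, 5, src.data(), /*ld_src=*/5, dst.data(), /*ld_dst=*/7);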
template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const uint16_t* src,
    int64_t ld_src,
    uint16_t* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      int64_t j = 0;
      for (; j < N / 16 * 16; j += 16) {
        transpose_16x16_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
      }
      // handle j rem
      int nrem = N - j;
      if (nrem > 0) {
        transpose_16x16_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem);
      }
    }
    // handle i rem
    int mrem = M - i;
    if (mrem > 0) {
      int64_t j = 0;
      for (; j < N / 16 * 16; j += 16) {
        transpose_16x16_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16);
      }
      // handle j rem
      int nrem = N - j;
      transpose_16x16_block(
          src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem);
    }
  }
}

template <>
void transpose_avx512(
    int64_t M,
    int64_t N,
    const uint8_t* src,
    int64_t ld_src,
    uint8_t* dst,
    int64_t ld_dst) {
  if (M == ld_dst && (M == 2 || M == 4)) {
    transpose_avx512_contiguous_wide(M, N, src, ld_src, dst, ld_dst);
  } else if (N == ld_src && (N == 2 || N == 4)) {
    transpose_avx512_contiguous_thin(M, N, src, ld_src, dst, ld_dst);
  } else {
    int64_t i = 0;
    for (; i < M / 16 * 16; i += 16) {
      int64_t j = 0;
      for (; j < N / 32 * 32; j += 32) {
        transpose_16x32_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst);
      }
      // handle j rem
      int nrem = N - j;
      if (nrem > 0) {
        transpose_16x32_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem);
      }
    }
    // handle i rem
    int mrem = M - i;
    if (mrem > 0) {
      int64_t j = 0;
      for (; j < N / 32 * 32; j += 32) {
        transpose_16x32_block(
            src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 32);
      }
      // handle j rem
      int nrem = N - j;
      transpose_16x32_block(
          src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem);
    }
  }
}

} // namespace internal
} // namespace fbgemm
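
// Documentation note added at the end of this file (a hedged sketch; the
// helper name transpose_ref below is hypothetical and not part of fbgemm):
// the uint16_t and uint8_t transpose_avx512 specializations above tile the
// matrix into 16 x 16 and 16 x 32 blocks respectively and pass mrem/nrem for
// the partial edge blocks. Every specialization in this file is expected to
// produce the same result as the scalar reference loop:
//
//   template <typename T>
//   void transpose_ref(
//       int64_t M,
//       int64_t N,
//       const T* src,
//       int64_t ld_src,
//       T* dst,
//       int64_t ld_dst) {
//     for (int64_t r = 0; r < M; ++r) {
//       for (int64_t c = 0; c < N; ++c) {
//         dst[c * ld_dst + r] = src[r * ld_src + c];
//       }
//     }
//   }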