sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/src/CodeGenHelpers.h

275 lines
8.0 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <asmjit/asmjit.h>
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"
namespace fbgemm {
namespace x86 = asmjit::x86;
/**
* @brief Create instruction sequence to generate 16-bit 1s
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:15] will have 0x0001, dest[16:31]
* will have 0x0001 and so on
*/
template <
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpsrlw(dest, dest, 15);
}
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
a->vpsrlw(dest, dest, 15);
}
/**
* @brief Emit instruction do load 32-bit integer. AVX512 has
* different instrunction to load registers with index >= 16
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest Destination vector register
*/
template <
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
}
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
}
/**
* @brief Emit partial extract from Wide regiter to Half Register, eg.
* Zmm -> Ymm or Ymm -> Xmm
* @tparam instSet instruction set to be used
*
* @param half Destination (half) vector register
* @param vec Source (full) vector register
* @param idx Index of of the half vector 0 or 1
*/
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Ymm half,
const x86::Zmm vec,
int idx) {
a->vextracti32x8(half, vec, idx);
}
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
a->vextracti32x4(half, vec, idx);
}
template <
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
a->vextracti128(half, vec, idx);
}
/**
* @brief Create instruction sequence to generate 8-bit 1s
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <
typename T,
typename std::enable_if<std::is_same<T, x86::Ymm>::value, int>::type = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpabsb(dest, dest);
}
template <
typename T,
typename std::enable_if<std::is_same<T, x86::Zmm>::value, int>::type = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
a->vpabsb(dest, dest);
}
/**
* @brief Generates instruction sequence to compute s32 += U8 * I8
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param cReg contains result
*
*/
template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
a->vpmaddubsw(tmpReg, aReg, bReg);
a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
a->vpaddd(cReg, tmpReg, cReg);
}
template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
typename simd_info<INST_SET>::vec_reg_t /*tmpReg*/) {
a->vpdpbusd(cReg, aReg, bReg);
}
/**
* @brief Add 4 consecutive numbers of type uint8
* and emit their sum as 32-bit numbers.
* i.e., dest[0:31] contains
* src[0:7] + src[8:15] + src[16:23] + src[24:31]
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest contains result
*
*/
template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
a->vpmaddubsw(tmpReg, src, tmpReg);
a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
a->vpaddd(dest, tmpReg, dest);
/*a->vxorps(tmpReg, tmpReg, tmpReg);*/
/*a->vmpsadbw(tmpReg, src, tmpReg, static_cast<asmjit::Imm>(0));*/
/*a->vpermilps(tmpReg, tmpReg, static_cast<asmjit::Imm>(4));*/
/*a->vpmovzxwd(tmpReg, tmpReg.half());*/
/*a->vpaddd(dest, tmpReg, dest);*/
}
template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
a->vpdpbusd(dest, src, tmpReg);
}
/**
* @brief Add 8 consecutive numbers of type uint8
* and emit their sum as 16-bit numbers.
* i.e., dest[0:15] contains
* src[0:7] + src[8:15] + src[16:23] + src[24:31]
* src[32:39] + src[40:47] + src[48:55] + src[56:63]
*
* and
*
* dest[64:79] contains
* src[64:71] + src[71:79] + src[80:87] + src[88:95]
* src[96:103] + src[104:111] + src[112:119] + src[120:127]
*
* so on
*
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest contains result
*
*/
template <typename T>
void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) {
a->vxorps(tmpReg, tmpReg, tmpReg);
a->vpsadbw(tmpReg, src, tmpReg);
a->vpaddd(dest, tmpReg, dest);
}
/**
* @brief Broadcast lower 8-bits of src to destination vector
* register.
*/
template <typename T>
void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
// move src to dest
auto xmm = dest.xmm();
a->movq(xmm, src);
a->vpbroadcastb(dest, xmm);
}
} // namespace fbgemm