sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/trtllm/fused_moe/IntFastDiv.h

150 lines
4.3 KiB
C++

/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// code in this file based on https://github.com/milakov/int_fastdiv
// note: above repo is not added as a submodule because we slightly update APIs
#include <cute/config.hpp>
#include <stdexcept>
namespace trtllm::dev {
////////////////////////////////////////////////////////////////////////////////////////////////////
// *************************************************************************************************
// IntFastDiv class.
// *************************************************************************************************
////////////////////////////////////////////////////////////////////////////////////////////////////
// Signed 32-bit divisor bundled with precomputed "magic" constants so that `int / IntFastDiv`
// can be evaluated with a multiply-high, an optional add/subtract of the dividend and a shift
// instead of an integer divide.
//
// The constants are computed by the constructors and by assignment from int, all of which are
// CUTE_HOST. The conversion back to int and the friend division operator are CUTE_HOST_DEVICE.
class IntFastDiv {
 public:
  // we allow a default constructor (and over-writing the divisor with assignment)
  // The default state is identical to the state produced for a divisor of 1.
  CUTE_HOST
  IntFastDiv() : mDivisor(1), mMagicM(0), mMagicS(-1), mAddSign(1) {}

  // Builds the fast-division constants for `divisor`.
  // Intentionally not `explicit`, so a plain int converts implicitly.
  // Throws std::runtime_error when `divisor` is 0.
  CUTE_HOST
  IntFastDiv(int divisor) : mDivisor(divisor), mMagicM(0), mMagicS(-1), mAddSign(1) {
    if (divisor == 0) throw std::runtime_error("IntFastDiv: cannot divide by 0");
    updateMagicNumbers();
  }

  // Replaces the divisor and recomputes the constants.
  // Throws std::runtime_error when `divisor` is 0; in that case *this is left unchanged.
  CUTE_HOST
  IntFastDiv& operator=(int divisor) {
    // Validate before touching any member: a rejected assignment must not leave the object
    // holding divisor 0 together with the magic numbers of the previous divisor.
    if (divisor == 0) throw std::runtime_error("IntFastDiv: cannot divide by 0");
    mDivisor = divisor;
    updateMagicNumbers();
    return *this;
  }

  // Returns the plain divisor value.
  CUTE_HOST_DEVICE
  operator int() const { return mDivisor; }

 private:
  int mDivisor;  // the divisor itself, never 0
  int mMagicM;   // multiplier fed to the multiply-high step
  int mMagicS;   // shift applied after the multiply; -1 means "no shift / no sign fix-up"
  int mAddSign;  // -1, 0 or +1: multiple of the dividend added after the multiply-high

  // Hacker's Delight, Second Edition, Chapter 10, Integer Division By Constants
  // Recomputes mMagicM, mMagicS and mAddSign from mDivisor. Requires mDivisor != 0.
  CUTE_HOST
  void updateMagicNumbers() {
    // +1 and -1 are handled separately: the quotient is just +/- the dividend, so the
    // multiplier is 0, the shift is disabled and mAddSign carries the sign.
    if (mDivisor == 1 || mDivisor == -1) {
      mMagicM = 0;
      mMagicS = -1;
      mAddSign = mDivisor;
      return;
    }

    unsigned int const two31 = 0x80000000u;
    unsigned int const divisorBits = static_cast<unsigned int>(mDivisor);
    // |mDivisor| computed in unsigned arithmetic: abs() is undefined for INT_MIN, whereas
    // the unsigned negation yields the intended 2^31.
    unsigned int const tmpAd = (mDivisor < 0) ? (0u - divisorBits) : divisorBits;
    unsigned int const t = two31 + (divisorBits >> 31);
    unsigned int const tmpAnc = t - 1 - t % tmpAd;

    int p = 31;
    unsigned int q1 = two31 / tmpAnc;
    unsigned int r1 = two31 - q1 * tmpAnc;
    unsigned int q2 = two31 / tmpAd;
    unsigned int r2 = two31 - q2 * tmpAd;
    unsigned int delta;
    do {
      ++p;
      // Double quotient/remainder of 2^p with respect to tmpAnc ...
      q1 = 2 * q1;
      r1 = 2 * r1;
      if (r1 >= tmpAnc) {
        ++q1;
        r1 -= tmpAnc;
      }
      // ... and with respect to |divisor|.
      q2 = 2 * q2;
      r2 = 2 * r2;
      if (r2 >= tmpAd) {
        ++q2;
        r2 -= tmpAd;
      }
      delta = tmpAd - r2;
    } while (q1 < delta || (q1 == delta && r1 == 0));

    // Negate in unsigned arithmetic (same bit pattern as the signed negation, but without
    // any possibility of signed overflow), then reinterpret as int.
    unsigned int magic = q2 + 1;
    if (mDivisor < 0) magic = 0u - magic;
    mMagicM = static_cast<int>(magic);
    mMagicS = p - 32;

    // When the multiplier's sign disagrees with the divisor's sign, the dividend has to be
    // added (positive divisor) or subtracted (negative divisor) after the multiply-high.
    if ((mDivisor > 0) && (mMagicM < 0))
      mAddSign = 1;
    else if ((mDivisor < 0) && (mMagicM > 0))
      mAddSign = -1;
    else
      mAddSign = 0;
  }

  CUTE_HOST_DEVICE
  friend int operator/(int const dividend, IntFastDiv const& divisor);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Divides `dividend` by the value wrapped in `divisor` using the constants precomputed by
// IntFastDiv: high 32 bits of (mMagicM * dividend), plus mAddSign * dividend, then an optional
// shift with a sign fix-up.
CUTE_HOST_DEVICE
int operator/(int const dividend, IntFastDiv const& divisor) {
  int quotient;
#ifdef __CUDA_ARCH__
  // Device code: signed multiply-high as a single PTX instruction.
  asm("mul.hi.s32 %0, %1, %2;" : "=r"(quotient) : "r"(divisor.mMagicM), "r"(dividend));
#else
  // Host code: form the 64-bit product and keep its upper 32 bits.
  long long const wideProduct =
      static_cast<long long>(divisor.mMagicM) * static_cast<long long>(dividend);
  quotient = (static_cast<unsigned long long>(wideProduct) >> 32);
#endif
  quotient += dividend * divisor.mAddSign;

  // A negative shift marks the cases (default-constructed, divisor +1 / -1) where the value
  // computed so far is already the result.
  if (divisor.mMagicS < 0) {
    return quotient;
  }

  quotient >>= divisor.mMagicS;  // we rely on this to be implemented as arithmetic shift
  // Add the sign bit (1 for a negative intermediate value, else 0).
  quotient += (static_cast<unsigned int>(quotient) >> 31);
  return quotient;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Remainder of `dividend` with respect to the value wrapped in `divisor`, derived from the
// fast quotient as dividend - quotient * divisor.
CUTE_HOST_DEVICE
int operator%(int const dividend, IntFastDiv const& divisor) {
  int const fastQuotient = dividend / divisor;
  // In the multiplication `divisor` is used through its conversion to int.
  return dividend - fastQuotient * divisor;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace trtllm::dev