/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#ifndef FLASHINFER_MATH_CUH_
|
|
#define FLASHINFER_MATH_CUH_
|
|
|
|
#include <cuda_fp16.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <cstdint>
|
|
|
|
namespace flashinfer {
|
|
namespace math {
|
|
|
|
/*! \brief log2(e). Since e^x == 2^(x * log2e), scaling by this converts a natural exponent to base 2. */
constexpr float log2e = 1.44269504088896340736f;

/*! \brief ln(2), i.e. the reciprocal of log2e. */
constexpr float loge2 = 0.693147180559945309417f;

/*!
 * \brief Large *finite* sentinel (50000.0f); despite the name this is not an IEEE infinity.
 * \note NOTE(review): presumably kept below the largest finite fp16 value (65504) so it stays
 *   finite after conversion to half -- confirm against the call sites.
 */
constexpr float inf = 5e4f;  // float literal, consistent with the constants above
|
|
|
|
/*!
 * \brief Reinterpret a 32-bit unsigned integer as a half2 without changing any bits.
 * \param x The 32-bit pattern holding two packed half values
 * \return A half2 whose storage is bit-identical to x
 */
__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
  return *reinterpret_cast<half2*>(&x);
}
|
|
|
|
/*!
 * \brief Reinterpret a half2 as a 32-bit unsigned integer without changing any bits.
 * \param x The packed pair of half values
 * \return The raw 32-bit pattern stored in x
 */
__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
  return *reinterpret_cast<uint32_t*>(&x);
}
|
|
|
|
/*!
 * \brief Approximate base-2 exponential, issued as a single PTX ex2.approx.ftz.f32.
 * \param x The exponent
 * \return The instruction's approximation of 2^x
 * \note The .ftz qualifier selects flush-to-zero handling of subnormal values.
 */
__forceinline__ __device__ float ptx_exp2(float x) {
  float result;
  asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
  return result;
}
|
|
|
|
/*!
 * \brief Approximate base-2 logarithm, issued as a single PTX lg2.approx.ftz.f32.
 * \param x The value whose logarithm is taken
 * \return The instruction's approximation of log2(x)
 * \note The .ftz qualifier selects flush-to-zero handling of subnormal values.
 */
__forceinline__ __device__ float ptx_log2(float x) {
  float result;
  asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
  return result;
}
|
|
|
|
/*!
 * \brief Approximate base-2 exponential of both halves of a half2, issued as a single PTX
 *   ex2.approx.f16x2.
 * \param x Packed pair of exponents
 * \return Packed pair holding the instruction's approximation of 2^x for each element
 * \note NOTE(review): the PTX ISA lists ex2.approx.f16x2 as requiring sm_75 or newer --
 *   confirm against the architectures this header is built for.
 */
__forceinline__ __device__ half2 ptx_exp2(half2 x) {
  // The instruction takes the packed pair in a 32-bit register, so go through raw bits.
  const uint32_t packed_in = half2_as_uint32(x);
  uint32_t packed_out;
  asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(packed_out) : "r"(packed_in));
  return uint32_as_half2(packed_out);
}
|
|
|
|
/*!
 * \brief Approximate base-2 exponential of a half, issued as a single PTX ex2.approx.f16.
 * \param x The exponent
 * \return The instruction's approximation of 2^x
 * \note NOTE(review): the PTX ISA lists ex2.approx.f16 as requiring sm_75 or newer --
 *   confirm against the architectures this header is built for.
 */
__forceinline__ __device__ half ptx_exp2(half x) {
  // Spelled `unsigned short` (the type __half_as_ushort/__ushort_as_half use) instead of
  // `ushort`, which is a non-standard typedef that not every host toolchain provides.
  unsigned short out_bits;
  // The "h" constraint binds a 16-bit register, so the half travels as its raw bit pattern.
  asm volatile("ex2.approx.f16 %0, %1;" : "=h"(out_bits) : "h"(__half_as_ushort(x)));
  return __ushort_as_half(out_bits);
}
|
|
|
|
/*!
 * \brief Approximate reciprocal, issued as a single PTX rcp.approx.ftz.f32.
 * \param x The divisor
 * \return The instruction's approximation of 1/x
 * \note The .ftz qualifier selects flush-to-zero handling of subnormal values.
 */
__forceinline__ __device__ float ptx_rcp(float x) {
  float reciprocal;
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(reciprocal) : "f"(x));
  return reciprocal;
}
|
|
|
|
/*!
 * \brief Butterfly (XOR) shuffle of a float between the lanes of a warp:
 *   y[i] <- x[i ^ lane_mask].
 * \param x The value contributed by the calling lane
 * \param lane_mask The value XOR-ed with the calling lane's index to select the source lane
 * \return The value of x held by lane (i ^ lane_mask), where i is the calling lane
 * \note The member mask is 0xffffffff, i.e. all 32 lanes of the warp are named as
 *   participants of the shuffle.
 */
__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
  // Same intrinsic, member mask and (default, full-warp) width as the half2 overload; it
  // lowers to the shfl.sync.bfly.b32 instruction that used to be written out as inline PTX.
  return __shfl_xor_sync(0xffffffff, x, lane_mask);
}
|
|
|
|
/*!
 * \brief Butterfly (XOR) shuffle of a half2 between the lanes of a warp:
 *   y[i] <- x[i ^ lane_mask].
 * \param x The packed pair contributed by the calling lane
 * \param lane_mask The value XOR-ed with the calling lane's index to select the source lane
 * \return The value of x held by lane (i ^ lane_mask), where i is the calling lane
 * \note The member mask is 0xffffffff, i.e. all 32 lanes of the warp are named as
 *   participants of the shuffle.
 */
__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
  constexpr unsigned int kFullWarpMask = 0xffffffff;
  const half2 from_partner = __shfl_xor_sync(kFullWarpMask, x, lane_mask);
  return from_partner;
}
|
|
|
|
/*!
 * \brief Approximate reciprocal square root, issued as a single PTX rsqrt.approx.ftz.f32.
 * \param x The value whose reciprocal square root is taken
 * \return The instruction's approximation of 1/sqrt(x)
 * \note The .ftz qualifier selects flush-to-zero handling of subnormal values.
 */
__forceinline__ __device__ float rsqrt(float x) {
  float inv_root;
  asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(inv_root) : "f"(x));
  return inv_root;
}
|
|
|
|
/*!
 * \brief Approximate hyperbolic tangent, issued as a single PTX tanh.approx.f32.
 * \param x The input value
 * \return The instruction's approximation of tanh(x)
 * \note NOTE(review): the PTX ISA lists tanh.approx.f32 as requiring sm_75 or newer --
 *   confirm against the architectures this header is built for.
 */
__forceinline__ __device__ float tanh(float x) {
  float result;
  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
  return result;
}
|
|
|
|
/*!
 * \brief Approximate hyperbolic tangent of both halves of a half2, issued as a single PTX
 *   tanh.approx.f16x2.
 * \param x Packed pair of inputs
 * \return Packed pair holding the instruction's approximation of tanh(x) for each element
 * \note NOTE(review): the PTX ISA lists tanh.approx.f16x2 as requiring sm_75 or newer --
 *   confirm against the architectures this header is built for.
 */
__forceinline__ __device__ half2 tanh(half2 x) {
  // The instruction takes the packed pair in a 32-bit register, so go through raw bits.
  const uint32_t packed_in = half2_as_uint32(x);
  uint32_t packed_out;
  asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(packed_out) : "r"(packed_in));
  return uint32_as_half2(packed_out);
}
|
|
|
|
/*!
 * \brief Approximate hyperbolic tangent of a half, issued as a single PTX tanh.approx.f16.
 * \param x The input value
 * \return The instruction's approximation of tanh(x)
 * \note NOTE(review): the PTX ISA lists tanh.approx.f16 as requiring sm_75 or newer --
 *   confirm against the architectures this header is built for.
 */
__forceinline__ __device__ half tanh(half x) {
  // Spelled `unsigned short` (the type __half_as_ushort/__ushort_as_half use) instead of
  // `ushort`, which is a non-standard typedef that not every host toolchain provides.
  unsigned short out_bits;
  // The "h" constraint binds a 16-bit register, so the half travels as its raw bit pattern.
  asm volatile("tanh.approx.f16 %0, %1;" : "=h"(out_bits) : "h"(__half_as_ushort(x)));
  return __ushort_as_half(out_bits);
}
|
|
|
|
} // namespace math
|
|
} // namespace flashinfer
|
|
#endif // FLASHINFER_MATH_CUH_
|