84 lines
2.5 KiB
Plaintext
84 lines
2.5 KiB
Plaintext
/*
|
|
* Copyright (c) 2023 by FlashInfer team.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#ifndef FLASHINFER_STATE_CUH_
|
|
#define FLASHINFER_STATE_CUH_
|
|
|
|
#include "../math.cuh"
|
|
#include "../vec_dtypes.cuh"
|
|
|
|
namespace flashinfer {
|
|
|
|
/*!
|
|
* \brief The flashattention state.
|
|
* \tparam vec_size The size of the vector used in o.
|
|
*/
|
|
template <size_t vec_size>
|
|
struct state_t {
|
|
/* the weighted sum of v: exp(pre-softmax logit - m) * v / d */
|
|
vec_t<float, vec_size> o;
|
|
/* maximum value of pre-softmax logits */
|
|
float m;
|
|
/* sum of exp(pre-softmax logits - m) */
|
|
float d;
|
|
|
|
__device__ __forceinline__ void init() {
|
|
o.fill(0.f);
|
|
m = -math::inf;
|
|
d = 1.f;
|
|
}
|
|
|
|
__device__ __forceinline__ state_t() { init(); }
|
|
|
|
__device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); }
|
|
|
|
/*!
|
|
* \brief Merge the state with another state.
|
|
* \param other_m The maximum value of pre-softmax logits of the other state.
|
|
* \param other_d The sum of exp(pre-softmax logits - m) of the other state.
|
|
* \param other_o The weighted sum of v of the other state.
|
|
*/
|
|
__device__ __forceinline__ void merge(const vec_t<float, vec_size>& other_o, float other_m,
|
|
float other_d) {
|
|
float m_prev = m, d_prev = d;
|
|
m = max(m_prev, other_m);
|
|
d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m);
|
|
#pragma unroll
|
|
for (size_t i = 0; i < vec_size; ++i) {
|
|
o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m);
|
|
}
|
|
}
|
|
|
|
/*!
|
|
* \brief Merge the state with another state.
|
|
* \param other The other state.
|
|
*/
|
|
__device__ __forceinline__ void merge(const state_t<vec_size>& other) {
|
|
merge(other.o, other.m, other.d);
|
|
}
|
|
|
|
__device__ __forceinline__ void normalize() {
|
|
// only normalize by d when not normalized on the fly
|
|
#pragma unroll
|
|
for (size_t i = 0; i < vec_size; ++i) {
|
|
o[i] = __fdividef(o[i], d);
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace flashinfer
|
|
|
|
#endif // FLASHINFER_STATE_CUH_
|