sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/state.cuh

84 lines
2.5 KiB
Plaintext

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_STATE_CUH_
#define FLASHINFER_STATE_CUH_
#include "../math.cuh"
#include "../vec_dtypes.cuh"
namespace flashinfer {
/*!
* \brief The flashattention state.
* \tparam vec_size The size of the vector used in o.
*/
template <size_t vec_size>
struct state_t {
/* the weighted sum of v: exp(pre-softmax logit - m) * v / d */
vec_t<float, vec_size> o;
/* maximum value of pre-softmax logits */
float m;
/* sum of exp(pre-softmax logits - m) */
float d;
__device__ __forceinline__ void init() {
o.fill(0.f);
m = -math::inf;
d = 1.f;
}
__device__ __forceinline__ state_t() { init(); }
__device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); }
/*!
* \brief Merge the state with another state.
* \param other_m The maximum value of pre-softmax logits of the other state.
* \param other_d The sum of exp(pre-softmax logits - m) of the other state.
* \param other_o The weighted sum of v of the other state.
*/
__device__ __forceinline__ void merge(const vec_t<float, vec_size>& other_o, float other_m,
float other_d) {
float m_prev = m, d_prev = d;
m = max(m_prev, other_m);
d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m);
}
}
/*!
* \brief Merge the state with another state.
* \param other The other state.
*/
__device__ __forceinline__ void merge(const state_t<vec_size>& other) {
merge(other.o, other.m, other.d);
}
__device__ __forceinline__ void normalize() {
// only normalize by d when not normalized on the fly
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = __fdividef(o[i], d);
}
}
};
} // namespace flashinfer
#endif // FLASHINFER_STATE_CUH_