/* * Copyright (c) 2023 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FLASHINFER_STATE_CUH_ #define FLASHINFER_STATE_CUH_ #include "../math.cuh" #include "../vec_dtypes.cuh" namespace flashinfer { /*! * \brief The flashattention state. * \tparam vec_size The size of the vector used in o. */ template struct state_t { /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */ vec_t o; /* maximum value of pre-softmax logits */ float m; /* sum of exp(pre-softmax logits - m) */ float d; __device__ __forceinline__ void init() { o.fill(0.f); m = -math::inf; d = 1.f; } __device__ __forceinline__ state_t() { init(); } __device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); } /*! * \brief Merge the state with another state. * \param other_m The maximum value of pre-softmax logits of the other state. * \param other_d The sum of exp(pre-softmax logits - m) of the other state. * \param other_o The weighted sum of v of the other state. */ __device__ __forceinline__ void merge(const vec_t& other_o, float other_m, float other_d) { float m_prev = m, d_prev = d; m = max(m_prev, other_m); d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m); #pragma unroll for (size_t i = 0; i < vec_size; ++i) { o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m); } } /*! * \brief Merge the state with another state. * \param other The other state. */ __device__ __forceinline__ void merge(const state_t& other) { merge(other.o, other.m, other.d); } __device__ __forceinline__ void normalize() { // only normalize by d when not normalized on the fly #pragma unroll for (size_t i = 0; i < vec_size; ++i) { o[i] = __fdividef(o[i], d); } } }; } // namespace flashinfer #endif // FLASHINFER_STATE_CUH_