/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <Python.h>
#include <torch/library.h>

#include <ATen/ATen.h>
#include <sstream>

#ifdef FLASHINFER_ENABLE_BF16
#include <cuda_bf16.h>
#endif
#ifdef FLASHINFER_ENABLE_F16
#include <cuda_fp16.h>
#endif
#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) || \
    defined(FLASHINFER_ENABLE_FP8_E8M0)
#include <cuda_fp8.h>
#endif
#if defined(FLASHINFER_ENABLE_FP4_E2M1)
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#include <cuda_fp4.h>
#endif
#endif

#ifndef FLASHINFER_EXT_MODULE_INITED
#define FLASHINFER_EXT_MODULE_INITED

// To expand macros in #name
#define FLASHINFER_EXT_MODULE_INIT_EXPAND(name) FLASHINFER_EXT_MODULE_INIT(name)

/* Creates a dummy empty module that can be imported from Python. Importing it
   from Python loads the .so built from this extension, which runs the
   TORCH_LIBRARY_FRAGMENT static initializers. */
#define FLASHINFER_EXT_MODULE_INIT(name)                                  \
  extern "C" {                                                            \
  __attribute__((weak)) PyObject* PyInit_##name(void) {                   \
    static struct PyModuleDef module_def = {                              \
        PyModuleDef_HEAD_INIT,                                            \
        #name, /* name of module */                                       \
        NULL,  /* module documentation, may be NULL */                    \
        -1,    /* size of per-interpreter state of the module,            \
                  or -1 if the module keeps state in global variables. */ \
        NULL,  /* methods */                                              \
        NULL,  /* slots */                                                \
        NULL,  /* traverse */                                             \
        NULL,  /* clear */                                                \
        NULL,  /* free */                                                 \
    };                                                                    \
    return PyModule_Create(&module_def);                                  \
  }                                                                       \
  }

FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)

#undef FLASHINFER_EXT_MODULE_INIT
#undef FLASHINFER_EXT_MODULE_INIT_EXPAND

#endif  // FLASHINFER_EXT_MODULE_INITED

#define _DISPATCH_CASE_I32(c_type, ...) \
  case at::ScalarType::Int: {           \
    using c_type = int32_t;             \
    return __VA_ARGS__();               \
  }

#define _DISPATCH_CASE_I64(c_type, ...) \
  case at::ScalarType::Long: {          \
    using c_type = int64_t;             \
    return __VA_ARGS__();               \
  }

#define DISPATCH_PYTORCH_IDTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)                  \
  [&]() -> bool {                                                                     \
    switch (pytorch_dtype) {                                                          \
      _DISPATCH_CASE_I32(c_type, __VA_ARGS__)                                         \
      _DISPATCH_CASE_I64(c_type, __VA_ARGS__)                                         \
      default:                                                                        \
        std::ostringstream oss;                                                       \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch idtype " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                \
        return false;                                                                 \
    }                                                                                 \
  }()

#define _DISPATCH_CASE_F32(c_type, ...) \
  case at::ScalarType::Float: {         \
    using c_type = float;               \
    return __VA_ARGS__();               \
  }

#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
  case at::ScalarType::Half: {          \
    using c_type = nv_half;             \
    return __VA_ARGS__();               \
  }
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
  case at::ScalarType::BFloat16: {       \
    using c_type = nv_bfloat16;          \
    return __VA_ARGS__();                \
  }
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
  case at::ScalarType::Float8_e4m3fn: {      \
    using c_type = __nv_fp8_e4m3;            \
    return __VA_ARGS__();                    \
  }
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif

#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
  case at::ScalarType::Float8_e5m2: {        \
    using c_type = __nv_fp8_e5m2;            \
    return __VA_ARGS__();                    \
  }
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
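// Usage sketch (illustrative, not part of this header's API surface): the
// DISPATCH_* macros take the runtime at::ScalarType, a name to bind the
// corresponding C++ type to, and a lambda to run under that binding. The
// kernel and tensor names below (`MyGatherKernel`, `indices`) are
// hypothetical placeholders.
//
//   bool ok = DISPATCH_PYTORCH_IDTYPE_TO_CTYPE(indices.scalar_type(), c_idtype, [&] {
//     MyGatherKernel<c_idtype>(indices.data_ptr<c_idtype>(), indices.numel());
//     return true;  // the outer dispatch lambda returns bool
//   });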
// Should not be used together with _DISPATCH_SF_CASE_FP8_E8M0:
// both claim at::ScalarType::Byte.
#if defined(FLASHINFER_ENABLE_FP4_E2M1) && \
    (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#define _DISPATCH_CASE_FP4_E2M1(c_type, ...) \
  case at::ScalarType::Byte: {               \
    using c_type = __nv_fp4_e2m1;            \
    return __VA_ARGS__();                    \
  }
#else
#define _DISPATCH_CASE_FP4_E2M1(c_type, ...)
#endif

// Should not be used together with _DISPATCH_CASE_FP4_E2M1:
// both claim at::ScalarType::Byte.
#if defined(FLASHINFER_ENABLE_FP8_E8M0) && \
    (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...) \
  case at::ScalarType::Byte: {                  \
    using c_type = __nv_fp8_e8m0;               \
    return __VA_ARGS__();                       \
  }
#else
#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)
#endif

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)                 \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...)                      \
  [&]() -> bool {                                                                            \
    switch (pytorch_dtype) {                                                                 \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                           \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                           \
      default:                                                                               \
        std::ostringstream oss;                                                              \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                       \
        return false;                                                                        \
    }                                                                                        \
  }()

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_SF(pytorch_dtype, c_type, ...)                \
  [&]() -> bool {                                                                     \
    switch (pytorch_dtype) {                                                          \
      _DISPATCH_CASE_F32(c_type, __VA_ARGS__)                                         \
      _DISPATCH_SF_CASE_FP8_E8M0(c_type, __VA_ARGS__)                                 \
      default:                                                                        \
        std::ostringstream oss;                                                       \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch scaling factor data type " \
            << pytorch_dtype;                                                         \
        TORCH_CHECK(false, oss.str());                                                \
        return false;                                                                 \
    }                                                                                 \
  }()

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)                      \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                       \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                       \
      _DISPATCH_CASE_FP4_E2M1(c_type, __VA_ARGS__)                                       \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()

#define _DISPATCH_SWITCH(var_name, cond, ...)                                           \
  [&]() -> bool {                                                                       \
    switch (cond) {                                                                     \
      __VA_ARGS__                                                                       \
      default:                                                                          \
        std::ostringstream oss;                                                         \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
        TORCH_CHECK(false, oss.str());                                                  \
        return false;                                                                   \
    }                                                                                   \
  }()

#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...)                       \
  [&]() -> bool {                                                                             \
    switch (pack_u16(cond1, cond2)) {                                                         \
      __VA_ARGS__                                                                             \
      default:                                                                                \
        std::ostringstream oss;                                                               \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \
            << int(cond1) << ", " << int(cond2) << ")";                                       \
        TORCH_CHECK(false, oss.str());                                                        \
        return false;                                                                         \
    }                                                                                         \
  }()

#define _DISPATCH_CASE(case_expr, case_var, ...) \
  case case_expr: {                              \
    constexpr auto case_var = case_expr;         \
    return __VA_ARGS__();                        \
  }

#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
  case pack_u16(case_expr1, case_expr2): {                                      \
    constexpr auto case_var1 = case_expr1;                                      \
    constexpr auto case_var2 = case_expr2;                                      \
    return __VA_ARGS__();                                                       \
  }
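// Usage sketch: _DISPATCH_SWITCH and _DISPATCH_CASE are building blocks for
// project-specific dispatchers that promote a runtime integer to a constexpr
// value. The DISPATCH_head_dim macro below is an illustrative example, not
// defined in this header, and the supported head dims are placeholders.
//
//   #define DISPATCH_head_dim(expr, const_expr, ...)              \
//     _DISPATCH_SWITCH("head_dim", expr,                          \
//                      _DISPATCH_CASE(64, const_expr, __VA_ARGS__)  \
//                      _DISPATCH_CASE(128, const_expr, __VA_ARGS__) \
//                      _DISPATCH_CASE(256, const_expr, __VA_ARGS__))
//
// Nested inside a dtype dispatch, a hypothetical launcher would look like:
//
//   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
//     return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
//       MyAttentionKernel<c_type, HEAD_DIM><<<grid, block>>>(/* ... */);
//       return true;
//     });
//   });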
#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool const_expr = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()

inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name,
                        const char* b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ",
              b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
  }
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads)                    \
  TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads(", num_qo_heads,  \
              ") must be divisible by num_kv_heads(", num_kv_heads, ")")

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define CHECK_LAST_DIM_CONTIGUOUS(x)                      \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1,   \
              #x " must be contiguous at the last dimension")

#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

#define CHECK_INPUT_TYPE(x, st) \
  TORCH_CHECK(x.scalar_type() == st, "Inconsistent dtype for tensor " #x)

#define CHECK_INPUT_AND_TYPE(x, st) \
  CHECK_CUDA(x);                    \
  CHECK_CONTIGUOUS(x);              \
  CHECK_INPUT_TYPE(x, st)

#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CUDA(x);                           \
  CHECK_LAST_DIM_CONTIGUOUS(x)

#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

inline bool is_float8_tensor(const at::Tensor& tensor) {
  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn ||
         tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
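// Usage sketch: typical input validation at the top of an op binding. The op
// and tensor names are hypothetical.
//
//   void my_attention_op(at::Tensor q, at::Tensor k, at::Tensor v) {
//     CHECK_INPUT(q);                      // CUDA + fully contiguous
//     CHECK_LAST_DIM_CONTIGUOUS_INPUT(k);  // CUDA + stride-1 last dim only
//     CHECK_DIM(3, q);                     // e.g. [seq_len, num_qo_heads, head_dim]
//     CHECK_SHAPE(k, v);                   // k and v must match elementwise in shape
//     CHECK_GQA_HEAD_DIVISIBLE(q.size(1), k.size(1));
//     if (is_float8_tensor(q)) {
//       // take the FP8 path (e4m3 or e5m2)
//     }
//   }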