// sglang_v0.5.2/flashinfer_0.3.1/csrc/pytorch_extension_utils.h
//
// (325 lines, 16 KiB, C++)

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <Python.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/library.h>
#ifdef FLASHINFER_ENABLE_BF16
#include <cuda_bf16.h>
#endif
#ifdef FLASHINFER_ENABLE_F16
#include <cuda_fp16.h>
#endif
#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) || \
defined(FLASHINFER_ENABLE_FP8_E8M0)
#include <cuda_fp8.h>
#endif
#if defined(FLASHINFER_ENABLE_FP4_E2M1)
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#include <cuda_fp4.h>
#endif
#endif
// Python module entry point for this extension shared object.
//
// Guarded by FLASHINFER_EXT_MODULE_INITED, so the PyInit_ function is emitted at most once
// per translation unit. A user can also suppress it by defining the guard before including
// this header. The generated function is declared __attribute__((weak)). When several
// translation units of the same extension include this header, the linker therefore keeps
// a single definition instead of reporting a duplicate symbol.
#ifndef FLASHINFER_EXT_MODULE_INITED
#define FLASHINFER_EXT_MODULE_INITED
// Two-level indirection: the argument (e.g. TORCH_EXTENSION_NAME) is macro-expanded
// before FLASHINFER_EXT_MODULE_INIT applies `#` and `##` to it. Otherwise the literal
// token "TORCH_EXTENSION_NAME" would be pasted/stringified.
#define FLASHINFER_EXT_MODULE_INIT_EXPAND(name) FLASHINFER_EXT_MODULE_INIT(name)
/* Creates a dummy empty module that can be imported from Python.
The import from Python will load the .so consisting of the file
in this extension, so that the TORCH_LIBRARY_FRAGMENT static initializers
are run. */
// The module is created via PyModule_Create from a static PyModuleDef with no methods,
// slots or GC hooks. It exists only so that `import <name>` loads the library.
#define FLASHINFER_EXT_MODULE_INIT(name) \
extern "C" { \
__attribute__((weak)) PyObject* PyInit_##name(void) { \
static struct PyModuleDef module_def = { \
PyModuleDef_HEAD_INIT, \
#name, /* name of module */ \
NULL, /* module documentation, may be NULL */ \
-1, /* size of per-interpreter state of the module, \
or -1 if the module keeps state in global variables. */ \
NULL, /* methods */ \
NULL, /* slots */ \
NULL, /* traverse */ \
NULL, /* clear */ \
NULL, /* free */ \
}; \
return PyModule_Create(&module_def); \
} \
}
// TORCH_EXTENSION_NAME must be defined at compile time. It is presumably supplied by the
// build (torch.utils.cpp_extension passes -DTORCH_EXTENSION_NAME=<name>); verify against
// the build scripts.
FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)
// The helper macros are only needed for the single expansion above.
#undef FLASHINFER_EXT_MODULE_INIT
#undef FLASHINFER_EXT_MODULE_INIT_EXPAND
#endif
// ---------------------------------------------------------------------------
// Integer (index) dtype dispatch.
//
// _DISPATCH_CASE_I32 / _DISPATCH_CASE_I64 each emit one `case` label. Inside it, the
// caller-chosen alias `c_type` names the matching C++ integer type, and the macro returns
// the result of invoking the callable passed as __VA_ARGS__.
// ---------------------------------------------------------------------------
#define _DISPATCH_CASE_I32(c_type, ...) \
  case at::ScalarType::Int: {           \
    using c_type = int32_t;             \
    return __VA_ARGS__();               \
  }
#define _DISPATCH_CASE_I64(c_type, ...) \
  case at::ScalarType::Long: {          \
    using c_type = int64_t;             \
    return __VA_ARGS__();               \
  }
// DISPATCH_PYTORCH_IDTYPE_TO_CTYPE(dtype, c_type, fn) switches on an Int/Long torch dtype
// and calls `fn()` with `c_type` aliased to int32_t / int64_t. The expression evaluates to
// fn()'s result converted to bool. Any other dtype fails a TORCH_CHECK whose message
// contains __PRETTY_FUNCTION__ and the offending dtype.
#define DISPATCH_PYTORCH_IDTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)           \
  [&]() -> bool {                                                              \
    switch (pytorch_dtype) {                                                   \
      _DISPATCH_CASE_I32(c_type, __VA_ARGS__)                                  \
      _DISPATCH_CASE_I64(c_type, __VA_ARGS__)                                  \
      default: {                                                               \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch idtype ", \
                    pytorch_dtype);                                            \
        return false;                                                          \
      }                                                                        \
    }                                                                          \
  }()
// Per-dtype `case` helpers for the floating-point dispatchers below. Each one expands to
// a `case at::ScalarType::X:` label. Inside the label, the caller-chosen alias `c_type`
// names the CUDA C++ type, and the helper returns the result of calling __VA_ARGS__().
// A helper whose feature macro is not defined expands to nothing. That dtype then falls
// through to the dispatcher's `default:` error.

// float32 is always available.
#define _DISPATCH_CASE_F32(c_type, ...) \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
}
// Half -> nv_half, only when FLASHINFER_ENABLE_F16 is defined.
#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = nv_half; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif
// BFloat16 -> nv_bfloat16, only when FLASHINFER_ENABLE_BF16 is defined.
#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = nv_bfloat16; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif
// Float8_e4m3fn -> __nv_fp8_e4m3, only when FLASHINFER_ENABLE_FP8_E4M3 is defined.
#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif
// Float8_e5m2 -> __nv_fp8_e5m2, only when FLASHINFER_ENABLE_FP8_E5M2 is defined.
#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
// Should not be used together with _DISPATCH_SF_CASE_FP8_E8M0: both are keyed on
// at::ScalarType::Byte (packed uint8 storage), so putting them in one switch would
// produce a duplicate case label.
// Requires FLASHINFER_ENABLE_FP4_E2M1 and nvcc >= 12.8. When __CUDACC_VER_MAJOR__ is not
// defined (e.g. a host-only compiler), the version test evaluates to 0 and the case
// expands to nothing.
#if defined(FLASHINFER_ENABLE_FP4_E2M1) && \
(__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#define _DISPATCH_CASE_FP4_E2M1(c_type, ...) \
case at::ScalarType::Byte: { \
using c_type = __nv_fp4_e2m1; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP4_E2M1(c_type, ...)
#endif
// Should not be used together with _DISPATCH_CASE_FP4_E2M1 (both use
// at::ScalarType::Byte; see above).
// Scale-factor case: Byte -> __nv_fp8_e8m0. Requires FLASHINFER_ENABLE_FP8_E8M0 and
// nvcc >= 12.8, with the same version-test behaviour as the FP4 case.
#if defined(FLASHINFER_ENABLE_FP8_E8M0) && \
(__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...) \
case at::ScalarType::Byte: { \
using c_type = __nv_fp8_e8m0; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_SF_CASE_FP8_E8M0(c_type, ...)
#endif
// ---------------------------------------------------------------------------
// Floating-point dtype dispatchers.
//
// Each macro expands to an immediately-invoked lambda that switches on `pytorch_dtype`.
// For every case enabled at compile time, it aliases `c_type` to the matching CUDA C++
// type and returns the callable's result converted to bool. A dtype without an enabled
// case fails a TORCH_CHECK whose message contains __PRETTY_FUNCTION__ and the dtype.
// ---------------------------------------------------------------------------

// 16-bit floats only: Half, BFloat16.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)          \
  [&]() -> bool {                                                                 \
    switch (pytorch_dtype) {                                                      \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                     \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                    \
      default: {                                                                  \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch data type ", \
                    pytorch_dtype);                                               \
        return false;                                                             \
      }                                                                           \
    }                                                                             \
  }()
// 8-bit floats only: Float8_e4m3fn, Float8_e5m2.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...)               \
  [&]() -> bool {                                                                     \
    switch (pytorch_dtype) {                                                          \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                    \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                    \
      default: {                                                                      \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch fp8 data type ", \
                    pytorch_dtype);                                                   \
        return false;                                                                 \
      }                                                                               \
    }                                                                                 \
  }()
// Scaling-factor dtypes: Float (float32) and, when enabled, Byte as __nv_fp8_e8m0.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_SF(pytorch_dtype, c_type, ...)                        \
  [&]() -> bool {                                                                             \
    switch (pytorch_dtype) {                                                                  \
      _DISPATCH_CASE_F32(c_type, __VA_ARGS__)                                                 \
      _DISPATCH_SF_CASE_FP8_E8M0(c_type, __VA_ARGS__)                                         \
      default: {                                                                              \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch scaling factor data type ", \
                    pytorch_dtype);                                                           \
        return false;                                                                         \
      }                                                                                       \
    }                                                                                         \
  }()
// Generic element dtypes: 16-bit floats, FP8 (e4m3/e5m2) and, when enabled, Byte as FP4.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)               \
  [&]() -> bool {                                                                 \
    switch (pytorch_dtype) {                                                      \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                     \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                    \
      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                \
      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                \
      _DISPATCH_CASE_FP4_E2M1(c_type, __VA_ARGS__)                                \
      default: {                                                                  \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch data type ", \
                    pytorch_dtype);                                               \
        return false;                                                             \
      }                                                                           \
    }                                                                             \
  }()
// _DISPATCH_SWITCH(var_name, cond, cases...) expands to an immediately-invoked lambda
// that switches on `cond` and runs the supplied `case` labels (typically built with
// _DISPATCH_CASE). `var_name` must be a string literal, because it is spliced into the
// error message by literal concatenation. An unmatched value fails a TORCH_CHECK that
// reports __PRETTY_FUNCTION__, `var_name`, and `cond` printed as an int.
#define _DISPATCH_SWITCH(var_name, cond, ...)                                        \
  [&]() -> bool {                                                                    \
    switch (cond) {                                                                  \
      __VA_ARGS__                                                                    \
      default: {                                                                     \
        TORCH_CHECK(false, __PRETTY_FUNCTION__, " failed to dispatch " var_name " ", \
                    static_cast<int>(cond));                                         \
        return false;                                                                \
      }                                                                              \
    }                                                                                \
  }()
// Two-key variant. It switches on pack_u16(cond1, cond2), so the case labels (typically
// _DISPATCH_CASE_U16x2) match a pair of values at once. Both names must be string
// literals. On failure both values are reported as ints.
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...)           \
  [&]() -> bool {                                                                 \
    switch (pack_u16(cond1, cond2)) {                                             \
      __VA_ARGS__                                                                 \
      default: {                                                                  \
        TORCH_CHECK(false, __PRETTY_FUNCTION__,                                   \
                    " failed to dispatch (" var1_name ", " var2_name "): (",      \
                    static_cast<int>(cond1), ", ", static_cast<int>(cond2), ")"); \
        return false;                                                             \
      }                                                                           \
    }                                                                             \
  }()
// _DISPATCH_CASE(case_expr, case_var, fn) emits one `case` label for use inside
// _DISPATCH_SWITCH. Within the case, `case_var` is a constexpr copy of `case_expr`, so
// `fn` may use it as a template argument. The label returns fn()'s result.
#define _DISPATCH_CASE(case_expr, case_var, ...) \
  case case_expr: {                              \
    constexpr auto case_var = case_expr;         \
    return __VA_ARGS__();                        \
  }
// Two-key variant for _DISPATCH_SWITCH_U16x2. The label value is
// pack_u16(case_expr1, case_expr2), and both keys are exposed as constexpr locals.
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
  case pack_u16(case_expr1, case_expr2): {                                      \
    constexpr auto case_var1 = case_expr1;                                      \
    constexpr auto case_var2 = case_expr2;                                      \
    return __VA_ARGS__();                                                       \
  }
// DISPATCH_BOOL(expr, const_expr, fn) lifts a runtime bool to a compile-time one.
// `expr` is evaluated once. `fn()` is then called with `const_expr` declared as a
// `constexpr bool` holding that value. The macro evaluates to fn()'s result as bool.
#define DISPATCH_BOOL(expr, const_expr, ...) \
  [&]() -> bool {                            \
    if (expr) {                              \
      constexpr bool const_expr = true;      \
      return __VA_ARGS__();                  \
    }                                        \
    constexpr bool const_expr = false;       \
    return __VA_ARGS__();                    \
  }()
/*!
 * \brief Check that two tensors have the same rank and the same size in every dimension.
 * \param a First tensor.
 * \param b Second tensor.
 * \param a_name Name used for `a` in error messages (CHECK_SHAPE passes the stringified
 *        argument).
 * \param b_name Name used for `b` in error messages.
 * \note Fails via TORCH_CHECK on the first mismatch. Only shapes are compared; strides,
 *       dtype and device are not checked.
 */
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name,
                        const char* b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ",
              b.dim());
  // dim() returns int64_t; use the same type for the index.
  const int64_t ndim = a.dim();
  for (int64_t i = 0; i < ndim; ++i) {
    // Report the offending sizes, like the rank-mismatch message above does.
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i,
                "). ", a.size(i), " vs ", b.size(i));
  }
}
// Packs two 16-bit keys into one 32-bit value, with `a` in the high half and `b` in the
// low half. It is constexpr so the result can be a `case` label (see _DISPATCH_CASE_U16x2
// and _DISPATCH_SWITCH_U16x2).
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return static_cast<uint32_t>(b) | (static_cast<uint32_t>(a) << 16u);
}
// ---------------------------------------------------------------------------
// Argument-validation macros. Each one fails via TORCH_CHECK.
//
// - Macro arguments are parenthesized wherever they are evaluated, so expression
//   arguments such as `a + b` or `*ptr` bind correctly. Arguments may be evaluated more
//   than once, so do not pass expressions with side effects.
// - Macros made of several checks are wrapped in `do { ... } while (0)`. They then act
//   as one statement (e.g. the body of an unbraced `if` guards every check) and must be
//   followed by a semicolon.
// ---------------------------------------------------------------------------

// Requires num_qo_heads to be a multiple of num_kv_heads. A zero num_kv_heads fails the
// check instead of performing an integer modulo by zero.
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads)                         \
  TORCH_CHECK((num_kv_heads) != 0 && (num_qo_heads) % (num_kv_heads) == 0,           \
              "num_qo_heads(", (num_qo_heads), ") must be divisible by num_kv_heads(", \
              (num_kv_heads), ")")
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Requires unit stride in the last dimension. A 0-dim tensor has no last dimension, so it
// is rejected up front instead of indexing strides() out of bounds.
#define CHECK_LAST_DIM_CONTIGUOUS(x)                 \
  TORCH_CHECK((x).dim() > 0 && (x).stride(-1) == 1, \
              #x " must be contiguous at last dimension")
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
#define CHECK_INPUT_TYPE(x, st) \
  TORCH_CHECK((x).scalar_type() == (st), "Inconsistency of Tensor type: " #x)
#define CHECK_INPUT_AND_TYPE(x, st) \
  do {                              \
    CHECK_CUDA(x);                  \
    CHECK_CONTIGUOUS(x);            \
    CHECK_INPUT_TYPE(x, st);        \
  } while (0)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  do {                                     \
    CHECK_CUDA(x);                         \
    CHECK_LAST_DIM_CONTIGUOUS(x);          \
  } while (0)
#define CHECK_DIM(d, x) TORCH_CHECK((x).dim() == (d), #x " must be a " #d "D tensor")
// Shape equality via check_shape(); the argument spellings become the names in messages.
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
  TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", (a), " vs ", (b))
#define CHECK_GE(a, b) \
  TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", (a), " vs ", (b))
// Returns true iff `tensor` has dtype Float8_e4m3fn or Float8_e5m2.
// Every other dtype, including other 8-bit formats, yields false.
inline bool is_float8_tensor(const at::Tensor& tensor) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Float8_e4m3fn:
    case at::ScalarType::Float8_e5m2:
      return true;
    default:
      return false;
  }
}