sglang_v0.5.2/flashinfer_0.3.1/flashinfer/cute_dsl/utils.py

224 lines
7.0 KiB
Python

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import ctypes
import functools
import importlib.util
from typing import Union
import cutlass
import cutlass._mlir.dialects.cute as _cute_ir
import torch
from cutlass._mlir import ir
from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type
def is_cute_dsl_available() -> bool:
    """Report whether the CuTe DSL Python packages can be located.

    Both the top-level ``cutlass`` package and its ``cutlass.cute``
    subpackage must be discoverable via ``importlib.util.find_spec``.
    The lookups run in order and stop at the first one that is missing.
    """
    required_modules = ("cutlass", "cutlass.cute")
    return all(
        importlib.util.find_spec(module_name) is not None
        for module_name in required_modules
    )
def get_cutlass_dtype(dtype: str) -> cutlass.dtype:
    """Map a dtype name string to the corresponding CuTe DSL numeric type.

    :param dtype: Name of the data type, e.g. ``"float16"``, ``"bfloat16"``,
        ``"float8_e4m3fn"`` or ``"float4_e2m1fn"``.
    :return: The ``cutlass`` numeric type registered for ``dtype``.
    :raises KeyError: If ``dtype`` is not a supported name; the message lists
        the accepted names.
    """
    dtype_map = {
        "float16": cutlass.Float16,
        "bfloat16": cutlass.BFloat16,
        "float32": cutlass.Float32,
        "float8_e5m2": cutlass.Float8E5M2,
        "float8_e4m3fn": cutlass.Float8E4M3FN,
        "float8_e8m0fnu": cutlass.Float8E8M0FNU,
        "float4_e2m1fn": cutlass.Float4E2M1FN,
    }
    try:
        return dtype_map[dtype]
    except KeyError:
        # Keep raising KeyError (callers may catch it), but say what *is*
        # supported instead of echoing only the offending key.
        raise KeyError(
            f"Unsupported dtype {dtype!r}; expected one of {sorted(dtype_map)}"
        ) from None
def cutlass_to_torch_dtype(cutlass_dtype):
    """
    Return the corresponding torch.dtype per the given DSL type.

    The lookup first tries the ``torch`` attribute named after the
    lower-cased class name of ``cutlass_dtype`` (e.g. ``Float32`` ->
    ``torch.float32``). DSL types whose names do not match torch's spelling
    fall back to an explicit name table.

    :param cutlass_dtype: A CuTe DSL numeric type, e.g. ``cutlass.BFloat16``.
    :return: The matching ``torch.dtype``.
    :raises TypeError: If no torch dtype corresponds to ``cutlass_dtype``
        (including when the installed torch lacks that dtype).
    """
    # Fast path: getattr on the torch module can also return functions or
    # submodules, so only a genuine torch.dtype is accepted.
    torch_dtype = getattr(torch, cutlass_dtype.__name__.lower(), None)
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype

    # Values are torch attribute *names*, resolved lazily below, so a torch
    # build lacking e.g. the float8 dtypes yields the TypeError below instead
    # of an AttributeError while building this table.
    torch_name_map = {
        cutlass.TFloat32: "float32",
        cutlass.Float32: "float32",
        cutlass.Float16: "float16",
        cutlass.BFloat16: "bfloat16",
        cutlass.Float8E5M2: "float8_e5m2",
        cutlass.Float8E4M3FN: "float8_e4m3fn",
        # Also produced by get_cutlass_dtype("float8_e8m0fnu"); only newer
        # torch releases define torch.float8_e8m0fnu.
        cutlass.Float8E8M0FNU: "float8_e8m0fnu",
        # NOTE(review): "B11" suggests an exponent bias of 11, while
        # torch.float8_e4m3fnuz presumably uses bias 8. The widths match, but
        # values would be misinterpreted; confirm this is only used for
        # storage.
        cutlass.Float8E4M3B11FNUZ: "float8_e4m3fnuz",
    }
    torch_name = torch_name_map.get(cutlass_dtype)
    if torch_name is not None:
        torch_dtype = getattr(torch, torch_name, None)
    else:
        torch_dtype = None
    if torch_dtype is None:
        raise TypeError(f"{cutlass_dtype} is not supported by torch")
    return torch_dtype
@functools.cache
def get_num_sm(device: torch.device) -> int:
    """Return the streaming multiprocessor (SM) count of ``device``.

    The count is read from ``torch.cuda.get_device_properties(device)``.
    ``functools.cache`` memoizes it per distinct ``device`` argument, so the
    device properties are queried at most once per argument value.
    """
    properties = torch.cuda.get_device_properties(device)
    return properties.multi_processor_count
# WAR (workaround) for the CuTeDSL make_ptr implementation, for flashinfer.
class _Pointer(Pointer):
    """Runtime representation of a pointer that can inter-operate with
    various data structures, including numpy arrays and device memory.

    :param pointer: The pointer to the data
    :type pointer: int or pointer-like object
    :param dtype: Data type of the elements pointed to
    :type dtype: Type
    :param mem_space: Memory space where the pointer resides, defaults generic
    :type mem_space: _cute_ir.AddressSpace, optional
    :param assumed_align: Alignment of input pointer in bytes. When None, it
        defaults to the element size in bytes (``dtype.width // 8``), clamped
        to at least 1 so sub-byte element types still get a valid alignment.
    :type assumed_align: int, optional

    :ivar _pointer: The underlying pointer
    :ivar _dtype: Data type of the elements
    :ivar _addr_space: Memory space of the pointer
    :ivar _assumed_align: Alignment of the pointer in bytes
    :ivar _desc: C-type descriptor for the pointer
    :ivar _c_pointer: C-compatible pointer representation
    """

    def __init__(
        self,
        pointer,
        dtype,
        mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
        assumed_align=None,
    ):
        self._pointer = pointer
        self._dtype = dtype
        self._addr_space = mem_space
        if assumed_align is None:
            # For sub-byte types (width < 8 bits, presumably e.g.
            # Float4E2M1FN) width // 8 is 0. That would make the modulo check
            # below raise ZeroDivisionError, so fall back to 1-byte alignment.
            self._assumed_align = max(dtype.width // 8, 1)
        else:
            self._assumed_align = assumed_align
        self._desc = None
        self._c_pointer = None
        # NOTE: assert statements are stripped under `python -O`; this stays
        # an assertion to keep the exception type unchanged for callers.
        assert int(self._pointer) % self._assumed_align == 0, (
            f"pointer must be {self._assumed_align} bytes aligned"
        )

    def size_in_bytes(self) -> int:
        # Size of the pointer value itself (a C void*), not of the pointee.
        return ctypes.sizeof(ctypes.c_void_p(int(self._pointer)))

    def __get_mlir_types__(self):
        return [self.mlir_type]

    def __c_pointers__(self):
        # Lazily build a c_void_p holding the address. It is stored on the
        # instance so the returned address (of that c_void_p) stays valid.
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        assert len(values) == 1
        return values[0]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        return self._dtype

    @property
    def memspace(self):
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        """Return True if ``expected_py_type`` describes a Pointer."""
        # NOTE(review): MLIR Python ``ir.Value`` objects expose ``.type``;
        # confirm the ``.ty`` attribute used here exists on CuTe DSL values.
        if expected_py_type is Pointer or (
            isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
        ):
            return True
        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()
def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Create a pointer from a memory address

    :param dtype: Data type of the pointer elements
    :type dtype: Type[Numeric]
    :param value: Memory address as integer or ctypes pointer
    :type value: Union[int, ctypes._Pointer]
    :param mem_space: Memory address space, defaults to AddressSpace.generic
    :type mem_space: AddressSpace, optional
    :param assumed_align: Alignment in bytes, defaults to None
    :type assumed_align: int, optional
    :return: A pointer object
    :rtype: Pointer
    :raises TypeError: If ``value`` is neither an int nor a ctypes pointer.

    .. code-block:: python

        import ctypes

        import cutlass
        import numpy as np

        from flashinfer.cute_dsl.utils import make_ptr

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)
        # Get a typed ctypes pointer to the array data
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        # Create pointer from address
        y = make_ptr(cutlass.Float32, ptr_address)
    """
    # Normalize `value` to a plain integer address.
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # c_void_p(...).value is None for a NULL pointer.
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )
    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)