224 lines
7.0 KiB
Python
224 lines
7.0 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import ctypes
|
|
import functools
|
|
import importlib.util
|
|
from typing import Union
|
|
|
|
import cutlass
|
|
import cutlass._mlir.dialects.cute as _cute_ir
|
|
import torch
|
|
from cutlass._mlir import ir
|
|
from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type
|
|
|
|
|
|
def is_cute_dsl_available() -> bool:
    """Report whether the CuTe DSL can be located by the import system.

    Both the ``cutlass`` package and its ``cutlass.cute`` subpackage are
    looked up with :func:`importlib.util.find_spec`.

    :return: ``True`` when both specs are found, ``False`` otherwise
    :rtype: bool
    """
    try:
        # all() short-circuits, so "cutlass.cute" is only probed once
        # "cutlass" itself has been found.
        return all(
            importlib.util.find_spec(name) is not None
            for name in ("cutlass", "cutlass.cute")
        )
    except ImportError:
        # Resolving a dotted name imports the parent module. A ``cutlass``
        # that is not a package, or whose import fails, raises here instead
        # of yielding None; an availability probe should answer False.
        return False
|
|
|
|
|
|
def get_cutlass_dtype(dtype: str) -> cutlass.dtype:
    """Translate a dtype name into the matching CUTLASS DSL type.

    :param dtype: One of ``"float16"``, ``"bfloat16"``, ``"float32"``,
        ``"float8_e5m2"``, ``"float8_e4m3fn"``, ``"float8_e8m0fnu"`` or
        ``"float4_e2m1fn"``
    :type dtype: str
    :return: The CUTLASS type registered under ``dtype``
    :raises KeyError: If ``dtype`` is not one of the names listed above
    """
    name_to_type = dict(
        float16=cutlass.Float16,
        bfloat16=cutlass.BFloat16,
        float32=cutlass.Float32,
        float8_e5m2=cutlass.Float8E5M2,
        float8_e4m3fn=cutlass.Float8E4M3FN,
        float8_e8m0fnu=cutlass.Float8E8M0FNU,
        float4_e2m1fn=cutlass.Float4E2M1FN,
    )
    return name_to_type[dtype]
|
|
|
|
|
|
def cutlass_to_torch_dtype(cutlass_dtype):
    """
    Return the corresponding torch.dtype per the given DSL type

    The type's ``__name__`` is lower-cased and looked up as an attribute of
    ``torch`` first. Only when that does not yield a ``torch.dtype`` is an
    explicit fallback table consulted.

    :param cutlass_dtype: DSL type object exposing ``__name__``
    :return: The matching ``torch.dtype``
    :rtype: torch.dtype
    :raises TypeError: If neither lookup produces a ``torch.dtype``
    """
    torch_dtype = getattr(torch, cutlass_dtype.__name__.lower(), None)
    # The lower-cased name may collide with a torch attribute that is not a
    # dtype (a function or submodule); accept the result only if it is one.
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype

    # Built only on this fallback path so that name-resolved types do not
    # require every dtype referenced below to exist in the installed torch.
    torch_type_map = {
        cutlass.TFloat32: torch.float32,
        cutlass.Float32: torch.float32,
        cutlass.Float16: torch.float16,
        cutlass.BFloat16: torch.bfloat16,
        cutlass.Float8E5M2: torch.float8_e5m2,
        cutlass.Float8E4M3FN: torch.float8_e4m3fn,
        cutlass.Float8E4M3B11FNUZ: torch.float8_e4m3fnuz,
    }
    torch_dtype = torch_type_map.get(cutlass_dtype)
    if torch_dtype is None:
        raise TypeError(f"{cutlass_dtype} is not supported by torch")
    return torch_dtype
|
|
|
|
|
|
@functools.cache
def get_num_sm(device: torch.device) -> int:
    """Return the multiprocessor (SM) count reported for ``device``.

    The value is read from ``torch.cuda.get_device_properties`` and memoized
    per ``device`` argument by :func:`functools.cache`.

    :param device: CUDA device to query
    :type device: torch.device
    :return: ``multi_processor_count`` of the device properties
    :rtype: int
    """
    device_props = torch.cuda.get_device_properties(device)
    return device_props.multi_processor_count
|
|
|
|
|
|
# Workaround (WAR) for the CuTe DSL make_ptr implementation, for use by FlashInfer
|
|
class _Pointer(Pointer):
    """Runtime representation of a pointer that can inter-operate with
    various data structures, including numpy arrays and device memory.

    :param pointer: The pointer to the data
    :type pointer: int or pointer-like object
    :param dtype: Data type of the elements pointed to
    :type dtype: Type
    :param mem_space: Memory space where the pointer resides, defaults generic
    :type mem_space: _cute_ir.AddressSpace, optional
    :param assumed_align: Alignment of input pointer in bytes, defaults None
        (derived from the element width, at least 1 byte)
    :type assumed_align: int, optional

    :ivar _pointer: The underlying pointer
    :ivar _dtype: Data type of the elements
    :ivar _addr_space: Memory space of the pointer
    :ivar _assumed_align: Alignment of the pointer in bytes
    :ivar _desc: C-type descriptor for the pointer
    :ivar _c_pointer: C-compatible pointer representation
    """

    def __init__(
        self,
        pointer,
        dtype,
        mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
        assumed_align=None,
    ):
        self._pointer = pointer
        self._dtype = dtype
        self._addr_space = mem_space

        if assumed_align is None:
            # Element types narrower than one byte floor to 0 here, which
            # would make the modulo below raise ZeroDivisionError; fall back
            # to byte alignment for them.
            self._assumed_align = max(dtype.width // 8, 1)
        else:
            self._assumed_align = assumed_align

        # Both are created lazily by __c_pointers__.
        self._desc = None
        self._c_pointer = None
        assert int(self._pointer) % self._assumed_align == 0, (
            f"pointer must be {self._assumed_align} bytes aligned"
        )

    def size_in_bytes(self) -> int:
        """Return the size of a C ``void *`` in bytes."""
        # sizeof accepts the ctypes type directly; no instance is needed.
        return ctypes.sizeof(ctypes.c_void_p)

    def __get_mlir_types__(self):
        """Return a one-element list holding this pointer's MLIR type."""
        return [self.mlir_type]

    def __c_pointers__(self):
        """Return a one-element list with the address of a ``c_void_p``
        that stores the pointer value.

        The ``c_void_p`` is created on first use and kept on the instance so
        the returned address keeps referring to a live object.
        """
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        """Return the single MLIR value in ``values`` unchanged."""
        assert len(values) == 1
        return values[0]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        """MLIR pointer type built from element type, address space and
        assumed alignment."""
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        """Data type of the elements pointed to."""
        return self._dtype

    @property
    def memspace(self):
        """Memory space the pointer was created with."""
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        """Not available on runtime pointers.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        """Return ``True`` when ``expected_py_type`` is ``Pointer`` itself, or
        is an ``ir.Value`` whose ``ty`` is ``Pointer``; ``False`` otherwise."""
        if expected_py_type is Pointer or (
            isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
        ):
            return True

        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()
|
|
|
|
|
|
def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Wrap a raw memory address in a runtime pointer object

    :param dtype: Data type of the pointer elements
    :type dtype: Type[Numeric]
    :param value: Memory address as integer or ctypes pointer
    :type value: Union[int, ctypes._Pointer]
    :param mem_space: Memory address space, defaults to AddressSpace.generic
    :type mem_space: AddressSpace, optional
    :param assumed_align: Alignment in bytes, defaults to None
    :type assumed_align: int, optional
    :return: A pointer object
    :rtype: Pointer
    :raises TypeError: If ``value`` is neither an int nor a ctypes pointer

    .. code-block:: python

        import ctypes

        import cutlass
        import numpy as np

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)

        # Obtain a ctypes pointer to the array data
        ptr = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        # Create pointer from the ctypes pointer
        y = make_ptr(cutlass.Float32, ptr)
    """
    # Integer addresses are used as-is.
    if isinstance(value, int):
        return _Pointer(value, dtype, mem_space, assumed_align=assumed_align)

    if not isinstance(value, ctypes._Pointer):
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    # Reinterpret as void* to read the raw address; ``.value`` is None for a
    # NULL pointer.
    raw_address = ctypes.cast(value, ctypes.c_void_p).value
    assert raw_address is not None, "Pointer address is None"
    return _Pointer(raw_address, dtype, mem_space, assumed_align=assumed_align)
|