"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import ctypes
import functools
import importlib.util
from typing import Union

import cutlass
import cutlass._mlir.dialects.cute as _cute_ir
import torch
from cutlass._mlir import ir
from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type


def is_cute_dsl_available() -> bool:
    """Return True when both ``cutlass`` and ``cutlass.cute`` can be located.

    Only import specs are looked up; the ``and`` short-circuits so the
    ``cutlass.cute`` lookup is skipped when ``cutlass`` itself is missing.
    """
    return (
        importlib.util.find_spec("cutlass") is not None
        and importlib.util.find_spec("cutlass.cute") is not None
    )


def get_cutlass_dtype(dtype: str) -> cutlass.dtype:
    """Map a dtype name to the corresponding CUTLASS DSL type.

    :param dtype: One of ``"float16"``, ``"bfloat16"``, ``"float32"``,
        ``"float8_e5m2"``, ``"float8_e4m3fn"``, ``"float8_e8m0fnu"`` or
        ``"float4_e2m1fn"``
    :type dtype: str
    :return: The matching CUTLASS DSL numeric type
    :raises KeyError: If ``dtype`` is not one of the names listed above
    """
    dtype_map = {
        "float16": cutlass.Float16,
        "bfloat16": cutlass.BFloat16,
        "float32": cutlass.Float32,
        "float8_e5m2": cutlass.Float8E5M2,
        "float8_e4m3fn": cutlass.Float8E4M3FN,
        "float8_e8m0fnu": cutlass.Float8E8M0FNU,
        "float4_e2m1fn": cutlass.Float4E2M1FN,
    }
    return dtype_map[dtype]


def cutlass_to_torch_dtype(cutlass_dtype):
    """
    Return the corresponding torch.dtype per the given DSL type

    The lookup first tries the lower-cased DSL type name as an attribute of
    ``torch`` (e.g. ``Float16`` -> ``torch.float16``) and falls back to an
    explicit table for types whose names do not line up.

    :param cutlass_dtype: A CUTLASS DSL numeric type
    :return: The matching torch dtype
    :rtype: torch.dtype
    :raises TypeError: If no torch dtype corresponds to ``cutlass_dtype``
    """
    torch_type_map = {
        cutlass.TFloat32: torch.float32,
        cutlass.Float32: torch.float32,
        cutlass.Float16: torch.float16,
        cutlass.BFloat16: torch.bfloat16,
        cutlass.Float8E5M2: torch.float8_e5m2,
        cutlass.Float8E4M3FN: torch.float8_e4m3fn,
        cutlass.Float8E4M3B11FNUZ: torch.float8_e4m3fnuz,
    }
    torch_dtype = getattr(torch, cutlass_dtype.__name__.lower(), None)
    # A name match on the torch module is only trusted when it really is a
    # dtype; anything else (function, submodule, ...) goes through the table.
    if not isinstance(torch_dtype, torch.dtype):
        torch_dtype = torch_type_map.get(cutlass_dtype)
    if torch_dtype is None:
        raise TypeError(f"{cutlass_dtype} is not supported by torch")
    return torch_dtype


@functools.cache
def get_num_sm(device: torch.device) -> int:
    """Return the multiprocessor (SM) count of ``device``.

    The result is memoized per ``device`` argument by ``functools.cache``.
    """
    return torch.cuda.get_device_properties(device).multi_processor_count


# WAR for CuTeDSL make_ptr implementation for flashinfer
class _Pointer(Pointer):
    """Runtime representation of a pointer that can inter-operate with
    various data structures, including numpy arrays and device memory.

    :param pointer: The pointer to the data
    :type pointer: int or pointer-like object
    :param dtype: Data type of the elements pointed to
    :type dtype: Type
    :param mem_space: Memory space where the pointer resides, defaults generic
    :type mem_space: _cute_ir.AddressSpace, optional
    :param assumed_align: Alignment of input pointer in bytes. When None, the
        element size in bytes is used, with a minimum of 1 byte so that
        sub-byte element types get a valid alignment
    :type assumed_align: int, optional

    :ivar _pointer: The underlying pointer
    :ivar _dtype: Data type of the elements
    :ivar _addr_space: Memory space of the pointer
    :ivar _assumed_align: Alignment of the pointer in bytes
    :ivar _desc: C-type descriptor for the pointer
    :ivar _c_pointer: C-compatible pointer representation
    """

    def __init__(
        self,
        pointer,
        dtype,
        mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
        assumed_align=None,
    ):
        self._pointer = pointer
        self._dtype = dtype
        self._addr_space = mem_space

        if assumed_align is None:
            # ``dtype.width`` is in bits. Sub-byte types (e.g. 4-bit floats)
            # would floor to 0 and make the modulo below divide by zero, so
            # the default alignment is clamped to at least one byte.
            self._assumed_align = max(1, dtype.width // 8)
        else:
            self._assumed_align = assumed_align

        # Both are created lazily by __c_pointers__.
        self._desc = None
        self._c_pointer = None
        assert int(self._pointer) % self._assumed_align == 0, (
            f"pointer must be {self._assumed_align} bytes aligned"
        )

    def size_in_bytes(self) -> int:
        """Return the size in bytes of a C ``void *`` holding this address."""
        return ctypes.sizeof(ctypes.c_void_p(int(self._pointer)))

    def __get_mlir_types__(self):
        """Return a one-element list holding this pointer's MLIR type."""
        return [self.mlir_type]

    def __c_pointers__(self):
        """Return a one-element list with the address of the C descriptor.

        The ``c_void_p`` descriptor is built on first use and stored on the
        instance so the address handed out stays backed by a live object.
        """
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        """Return the single MLIR value in ``values`` unchanged."""
        assert len(values) == 1
        return values[0]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        """MLIR pointer type built from dtype, address space and alignment."""
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        """Data type of the pointed-to elements."""
        return self._dtype

    @property
    def memspace(self):
        """Address space this pointer was created with."""
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        """Unsupported for runtime pointers.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        """Return True when ``expected_py_type`` denotes a ``Pointer``.

        That is the case when it is the ``Pointer`` class itself, or an
        ``ir.Value`` whose ``ty`` attribute is ``Pointer``.
        """
        if expected_py_type is Pointer or (
            isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
        ):
            return True

        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()


def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Create a pointer from a memory address

    :param dtype: Data type of the pointer elements
    :type dtype: Type[Numeric]
    :param value: Memory address as integer or ctypes pointer
    :type value: Union[int, ctypes._Pointer]
    :param mem_space: Memory address space, defaults to AddressSpace.generic
    :type mem_space: AddressSpace, optional
    :param assumed_align: Alignment in bytes, defaults to None
    :type assumed_align: int, optional
    :return: A pointer object
    :rtype: Pointer
    :raises TypeError: If ``value`` is neither an int nor a ctypes pointer

    .. code-block:: python

        import numpy as np
        import ctypes

        from cutlass import Float32

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)

        # Get a ctypes pointer to the array data
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        # Create pointer from the ctypes pointer (make_ptr is this function)
        y = make_ptr(Float32, ptr_address)
    """
    # check if value is int or ctypes.POINTER
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # get address value; a NULL ctypes pointer casts to a value of None
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)