""" Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import ctypes from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch.distributed as dist from torch.distributed import ProcessGroup # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings. # However, cuda-python's API is not stable yet, so we use ctypes bindings instead. # which is copied from vllm codebase. cudaError_t = ctypes.c_int cudaMemcpyKind = ctypes.c_int class cudaIpcMemHandle_t(ctypes.Structure): _fields_ = [("internal", ctypes.c_byte * 128)] @dataclass class Function: name: str restype: Any argtypes: List[Any] def find_loaded_library(lib_name) -> Optional[str]: """ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. """ # noqa found = False with open("/proc/self/maps") as f: for line in f: if lib_name in line: found = True break if not found: # the library is not loaded in the current process return None # if lib_name is libcudart, we need to match a line with: # address /path/to/libcudart-hash.so.11.0 start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" ) return path class CudaRTLibrary: """CudaRTLibrary""" exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), # cudaError_t cudaDeviceSynchronize ( void ) Function("cudaDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) Function( "cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], ), # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) Function( "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] ), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa Function( "cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], ), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa Function( "cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p], ), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa Function( "cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint], ), ] # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times path_to_library_cache: Dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): if so_file is None: so_file = find_loaded_library("libcudart") assert so_file is not None, "libcudart is not loaded in the current process" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib self.lib = CudaRTLibrary.path_to_library_cache[so_file] if so_file not in CudaRTLibrary.path_to_dict_mapping: _funcs = {} for func in CudaRTLibrary.exported_functions: f = getattr(self.lib, func.name) f.restype = func.restype f.argtypes = func.argtypes _funcs[func.name] = f CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] def CUDART_CHECK(self, result: cudaError_t) -> None: if result != 0: error_str = self.cudaGetErrorString(result) raise RuntimeError(f"CUDART error: {error_str}") def cudaGetErrorString(self, error: cudaError_t) -> str: return self.funcs["cudaGetErrorString"](error).decode("utf-8") def cudaSetDevice(self, device: int) -> None: self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) def cudaDeviceSynchronize(self) -> None: self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) def cudaDeviceReset(self) -> None: self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) def cudaMalloc(self, size: int) -> ctypes.c_void_p: devPtr = ctypes.c_void_p() self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) return devPtr def cudaFree(self, devPtr: ctypes.c_void_p) -> None: self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) def cudaMemcpy( self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int ) -> None: cudaMemcpyDefault = 4 kind = cudaMemcpyDefault self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: handle = cudaIpcMemHandle_t() self.CUDART_CHECK( self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr) ) return handle def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudaIpcMemLazyEnablePeerAccess = 1 devPtr = ctypes.c_void_p() self.CUDART_CHECK( self.funcs["cudaIpcOpenMemHandle"]( ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess ) ) return devPtr cudart = CudaRTLibrary() def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None ) -> List[int]: """ Creates a shared buffer and returns a list of pointers representing the buffer on all processes in the group. """ pointer = cudart.cudaMalloc(size_in_bytes) handle = cudart.cudaIpcGetMemHandle(pointer) if group is None: group = dist.group.WORLD world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) handles = [None] * world_size dist.all_gather_object(handles, handle, group=group) handles = [None] * world_size dist.all_gather_object(handles, handle, group=group) pointers: List[int] = [] for i, h in enumerate(handles): if i == rank: pointers.append(pointer.value) else: pointers.append(cudart.cudaIpcOpenMemHandle(h).value) dist.barrier(group=group) return pointers def free_shared_buffer( pointers: List[int], group: Optional[ProcessGroup] = None ) -> None: """ Frees a shared buffer. """ if group is None: group = dist.group.WORLD rank = dist.get_rank(group=group) if pointers and len(pointers) > rank and pointers[rank] is not None: cudart.cudaFree(ctypes.c_void_p(pointers[rank])) dist.barrier(group=group)