import triton  # type: ignore[import]
import triton.language as tl  # type: ignore[import]


def matmul_get_configs():
    # Candidate autotuning configurations (currently a single tile shape).
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": BM,
                "BLOCK_SIZE_N": BN,
                "BLOCK_SIZE_K": BK,
                "GROUP_SIZE_M": 8,
            },
            num_stages=s,
            num_warps=w,
        )
        for BM in [128]
        for BN in [128]
        for BK in [64]
        for s in [3]
        for w in [4]
    ]


def _matmul_launch_metadata(grid, kernel, args):
    # Launch metadata for profilers: kernel name plus FLOP and byte counts.
    ret = {}
    M, N, K = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
    return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    # Map a linear tile index to (pid_m, pid_n), visiting tiles in
    # GROUP_SIZE_M-tall groups to improve L2 cache reuse.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
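
# A minimal host-side sketch of the grouped tile order that _compute_pid
# implements; the helper name `_reference_tile_order` is illustrative only and
# is not used by the kernels below.
def _reference_tile_order(num_pid_m, num_pid_n, group_size_m):
    order = []
    num_pid_in_group = group_size_m * num_pid_n
    for tile_id in range(num_pid_m * num_pid_n):
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * group_size_m
        size_m = min(num_pid_m - first_pid_m, group_size_m)
        pid_m = first_pid_m + (tile_id % size_m)
        pid_n = (tile_id % num_pid_in_group) // size_m
        order.append((pid_m, pid_n))
    return order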

@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # NOTE: There is currently a bug in Blackwell pipelining that means it
    # can't handle a value being used in both the prologue and epilogue, so
    # we duplicate the counters as a work-around.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # Persistent loop: each program handles tiles start_pid,
    # start_pid + NUM_SMS, start_pid + 2 * NUM_SMS, ...
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        # Epilogue: C = alpha * (A @ B) + beta * C
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        c = accumulator.to(c_ptr.dtype.element_ty)
        c = tl.fma(c, alpha, beta * tl.load(c_ptrs, mask=c_mask))
        tl.store(c_ptrs, c, mask=c_mask)
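
# A minimal host-side launch sketch for gemm_kernel_persistent, assuming
# PyTorch and a CUDA device; the wrapper name and the alpha/beta defaults are
# illustrative. The grid is capped at the SM count so each program stays
# resident and loops over many output tiles.
import torch  # type: ignore[import]


def matmul_persistent(a, b, alpha=1.0, beta=0.0):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "incompatible dimensions"
    c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    # At most one program per SM; BLOCK_SIZE_* come from the autotuner.
    grid = lambda META: (
        min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )
    gemm_kernel_persistent[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        alpha, beta,
        NUM_SMS=NUM_SMS,
    )
    return c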

@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel_descriptor_persistent(
    a_ptr, b_ptr, c_ptr,  #
    M, N, K,  #
    alpha, beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    EPILOGUE_SUBTILE: tl.constexpr,  #
    NUM_SMS: tl.constexpr,
):  #
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    # B is described as (N, K) row-major, i.e. already transposed; the inner
    # loop therefore computes tl.dot(a, b.T).
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[
            BLOCK_SIZE_M,
            BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2,
        ],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M
        offs_cn = pid_n * BLOCK_SIZE_N

        if EPILOGUE_SUBTILE:
            # Store the tile in two halves along N to reduce the epilogue's
            # register and shared-memory pressure.
            acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
            acc = tl.permute(acc, (0, 2, 1))
            acc0, acc1 = tl.split(acc)
            acc0 = tl.fma(acc0, alpha, beta * c_desc.load([offs_cm, offs_cn]))
            acc1 = tl.fma(
                acc1, alpha, beta * c_desc.load([offs_cm, offs_cn + BLOCK_SIZE_N // 2])
            )
            c0 = acc0.to(dtype)
            c_desc.store([offs_cm, offs_cn], c0)
            c1 = acc1.to(dtype)
            c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
        else:
            accumulator = tl.fma(
                accumulator, alpha, beta * c_desc.load([offs_cm, offs_cn])
            )
            c = accumulator.to(dtype)
            c_desc.store([offs_cm, offs_cn], c)


# only for testing
@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def gemm_kernel(
    a_ptr, b_ptr, c_ptr,  #
    M, N, K,  #
    stride_am, stride_ak,  #
    stride_bk, stride_bn,  #
    stride_cm, stride_cn,  #
    alpha, beta,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
    offs_am = tl.where(offs_am < M, offs_am, 0)
    offs_bn = tl.where(offs_bn < N, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    c = tl.fma(c, alpha, beta * tl.load(c_ptrs, mask=c_mask))
    tl.store(c_ptrs, c, mask=c_mask)
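
# A minimal launch sketch for gemm_kernel_descriptor_persistent, assuming
# PyTorch, a GPU with TMA support (Hopper or newer), and a recent Triton that
# provides triton.set_allocator. The wrapper name, the block sizes, and the
# allocator below are illustrative, not tuned values; note that B is expected
# pre-transposed, with shape (N, K), to match the descriptor in the kernel.


def _alloc_fn(size: int, alignment: int, stream):
    # Scratch buffer Triton uses to materialize device-side TMA descriptors.
    return torch.empty(size, device="cuda", dtype=torch.int8)


def matmul_descriptor_persistent(a, b, alpha=1.0, beta=0.0):
    M, K = a.shape
    N, K2 = b.shape  # b holds B^T, matching the (N, K) descriptor in the kernel
    assert K == K2, "incompatible dimensions"
    c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    triton.set_allocator(_alloc_fn)
    grid = lambda META: (
        min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )
    # This kernel is not autotuned, so all constexpr parameters are passed
    # explicitly here.
    gemm_kernel_descriptor_persistent[grid](
        a, b, c,
        M, N, K,
        alpha, beta,
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64,
        GROUP_SIZE_M=8, EPILOGUE_SUBTILE=False,
        NUM_SMS=NUM_SMS,
    )
    return c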