/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/cascade.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

// Merge two attention states (v_a, s_a) and (v_b, s_b), computed over disjoint KV segments,
// into (v_merged, s_merged).
// v_*: [seq_len, num_heads, head_dim] (fp16-family), s_*: [seq_len, num_heads] (fp32 log-sum-exp).
void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b,
                 at::Tensor v_merged, at::Tensor s_merged) {
  CHECK_INPUT(v_a);
  CHECK_INPUT(s_a);
  CHECK_INPUT(v_b);
  CHECK_INPUT(s_b);
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  unsigned int seq_len = v_a.size(0);
  unsigned int num_heads = v_a.size(1);
  unsigned int head_dim = v_a.size(2);

  const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
    cudaError_t status = MergeState(
        static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
        static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
        static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
        seq_len, num_heads, head_dim, stream);
    TORCH_CHECK(status == cudaSuccess,
                "MergeState kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });
  TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
}

// In-place variant: merge (v_other, s_other) into (v, s). The optional boolean `mask`
// ([seq_len]) selects which positions participate in the merge.
void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other,
                          std::optional<at::Tensor> mask) {
  CHECK_INPUT(v);
  CHECK_INPUT(s);
  CHECK_INPUT(v_other);
  CHECK_INPUT(s_other);
  auto device = v.device();
  CHECK_EQ(s.device(), device);
  CHECK_EQ(v_other.device(), device);
  CHECK_EQ(s_other.device(), device);
  CHECK_DIM(3, v);
  CHECK_DIM(2, s);
  CHECK_DIM(3, v_other);
  CHECK_DIM(2, s_other);
  CHECK_SHAPE(v, v_other);
  CHECK_SHAPE(s, s_other);
  CHECK_EQ(v.size(0), s.size(0));
  CHECK_EQ(v.size(1), s.size(1));
  uint8_t* mask_ptr = nullptr;
  if (mask.has_value()) {
    CHECK_DIM(1, mask.value());
    CHECK_EQ(v.size(0), mask.value().size(0));
    CHECK_EQ(mask.value().device(), device);
    mask_ptr = static_cast<uint8_t*>(mask.value().data_ptr());
  }
  unsigned int seq_len = v.size(0);
  unsigned int num_heads = v.size(1);
  unsigned int head_dim = v.size(2);

  const c10::cuda::OptionalCUDAGuard device_guard(v.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
    cudaError_t status = MergeStateInPlace(
        static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
        static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()),
        seq_len, num_heads, head_dim, mask_ptr, stream);
    TORCH_CHECK(status == cudaSuccess,
                "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });
  TORCH_CHECK(success, "MergeStateInPlace kernel launch failed: unsupported data type");
}
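// ---------------------------------------------------------------------------
// Illustrative sketch (not called by the wrappers in this file): the per-element
// merge rule these kernels apply for each (position, head, channel), assuming `s`
// holds the log-sum-exp of the attention scores for its KV segment (natural log
// in this sketch). The real computation runs on the GPU inside
// flashinfer::MergeState / MergeStateInPlace / MergeStates.
// ---------------------------------------------------------------------------
#include <cmath>  // only needed by the reference helper below

namespace {

[[maybe_unused]] inline void ReferenceMergeScalar(float v_a, float s_a, float v_b, float s_b,
                                                  float* v_merged, float* s_merged) {
  float s_max = s_a > s_b ? s_a : s_b;  // subtract the max for numerical stability
  float w_a = std::exp(s_a - s_max);
  float w_b = std::exp(s_b - s_max);
  *s_merged = s_max + std::log(w_a + w_b);            // log(exp(s_a) + exp(s_b))
  *v_merged = (w_a * v_a + w_b * v_b) / (w_a + w_b);  // score-weighted average of the outputs
}

}  // namespace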
// Merge `num_index_sets` attention states per position into a single state.
// v: [seq_len, num_index_sets, num_heads, head_dim], s: [seq_len, num_index_sets, num_heads];
// v_merged: [seq_len, num_heads, head_dim], s_merged: [seq_len, num_heads].
void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged) {
  CHECK_INPUT(v);
  CHECK_INPUT(s);
  auto device = v.device();
  CHECK_EQ(s.device(), device);
  CHECK_DIM(4, v);
  CHECK_DIM(3, s);
  CHECK_EQ(v.size(0), s.size(0));
  CHECK_EQ(v.size(1), s.size(1));
  CHECK_EQ(v.size(2), s.size(2));
  unsigned int seq_len = v.size(0);
  unsigned int num_index_sets = v.size(1);
  unsigned int num_heads = v.size(2);
  unsigned int head_dim = v.size(3);

  const c10::cuda::OptionalCUDAGuard device_guard(v.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
    cudaError_t status = MergeStates(
        static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
        static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
        num_index_sets, seq_len, num_heads, head_dim, stream);
    TORCH_CHECK(status == cudaSuccess,
                "MergeStates kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });
  TORCH_CHECK(success, "MergeStates kernel launch failed: unsupported data type");
}
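// No Python bindings are defined in this translation unit; registration presumably
// lives elsewhere in the extension. For illustration only, a minimal pybind11-style
// registration could look roughly like this (module name and docstrings are hypothetical):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("merge_state", &merge_state, "Merge two attention states");
//     m.def("merge_state_in_place", &merge_state_in_place,
//           "Merge one attention state into another in place");
//     m.def("merge_states", &merge_states, "Merge multiple attention states");
//   }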