from typing import Optional

import torch


# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``causal_conv1d_fwd`` custom op.

    Thin wrapper: every argument is forwarded positionally, unchanged and
    in declaration order, to ``torch.ops.sgl_kernel.causal_conv1d_fwd``.
    No validation or conversion happens on the Python side.

    Args:
        x: Input tensor handed to the kernel.
        weight: Convolution weight tensor handed to the kernel.
        bias_: Optional bias tensor; ``None`` is forwarded as-is.
        conv_states: Optional tensor forwarded as-is.
        query_start_loc: Optional tensor forwarded as-is.
        cache_indices: Optional tensor forwarded as-is.
        has_initial_state: Optional tensor forwarded as-is.
        silu_activation: Boolean flag forwarded to the kernel.
        pad_slot_id: Integer forwarded to the kernel.

    Returns:
        None. Whatever the op returns is discarded.

    NOTE(review): since nothing is returned, the kernel presumably writes
    its results in place (into ``x`` and/or ``conv_states``) -- confirm
    against the op's registered schema. Tensor shapes, dtypes and the exact
    meaning of each optional argument are defined by the kernel and are not
    visible from this wrapper.
    """
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``causal_conv1d_update`` custom op.

    Thin wrapper: every argument is forwarded positionally, unchanged and
    in declaration order, to ``torch.ops.sgl_kernel.causal_conv1d_update``.
    No validation or conversion happens on the Python side.

    Args:
        x: Input tensor handed to the kernel.
        conv_state: State tensor handed to the kernel.
        weight: Convolution weight tensor handed to the kernel.
        bias_: Optional bias tensor; ``None`` is forwarded as-is.
        silu_activation: Boolean flag forwarded to the kernel.
        cache_seqlens: Optional tensor forwarded as-is.
        conv_state_indices: Optional tensor forwarded as-is.
        pad_slot_id: Integer forwarded to the kernel.

    Returns:
        None. Whatever the op returns is discarded.

    NOTE(review): since nothing is returned, the kernel presumably updates
    ``x`` and/or ``conv_state`` in place -- confirm against the op's
    registered schema. Tensor shapes and dtypes are defined by the kernel
    and are not visible from this wrapper.
    """
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )