51 lines
1.1 KiB
Python
51 lines
1.1 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
# mamba
|
|
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``causal_conv1d_fwd`` custom op.

    This is a thin pass-through: every argument is handed to
    ``torch.ops.sgl_kernel.causal_conv1d_fwd`` positionally, unchanged and
    in declaration order. No validation, conversion, or defaulting happens
    here.

    Args:
        x: Input tensor forwarded to the op.
            NOTE(review): shape/dtype/layout requirements are defined by
            the kernel, not by this wrapper -- confirm against the op schema.
        weight: Weight tensor forwarded to the op.
        bias_: Optional tensor forwarded as the op's third argument
            (presumably the convolution bias -- confirm); may be ``None``.
        conv_states: Optional tensor forwarded to the op; may be ``None``.
        query_start_loc: Optional tensor forwarded to the op; may be
            ``None``.
        cache_indices: Optional tensor forwarded to the op; may be ``None``.
        has_initial_state: Optional tensor forwarded to the op; may be
            ``None``.
        silu_activation: Boolean flag forwarded to the op (presumably
            toggles a fused SiLU activation -- confirm).
        pad_slot_id: Integer forwarded to the op (presumably the sentinel
            marking padded cache slots -- confirm).

    Returns:
        None. Whatever the op returns is discarded.
        NOTE(review): this suggests the kernel writes its results into the
        tensors it is given -- confirm against the op's registration.
    """
    # The op is resolved at call time, so importing this module does not
    # require the sgl_kernel extension to be loaded yet.
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )
|
|
|
|
|
|
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` ``causal_conv1d_update`` custom op.

    This is a thin pass-through: every argument is handed to
    ``torch.ops.sgl_kernel.causal_conv1d_update`` positionally, unchanged
    and in declaration order. No validation, conversion, or defaulting
    happens here.

    Args:
        x: Input tensor forwarded to the op.
            NOTE(review): shape/dtype/layout requirements are defined by
            the kernel, not by this wrapper -- confirm against the op schema.
        conv_state: Tensor forwarded as the op's second argument
            (presumably the rolling convolution state -- confirm).
        weight: Weight tensor forwarded to the op.
        bias_: Optional tensor forwarded as the op's fourth argument
            (presumably the convolution bias -- confirm); may be ``None``.
        silu_activation: Boolean flag forwarded to the op (presumably
            toggles a fused SiLU activation -- confirm).
        cache_seqlens: Optional tensor forwarded to the op; may be ``None``.
        conv_state_indices: Optional tensor forwarded to the op; may be
            ``None``.
        pad_slot_id: Integer forwarded to the op (presumably the sentinel
            marking padded cache slots -- confirm).

    Returns:
        None. Whatever the op returns is discarded.
        NOTE(review): this suggests the kernel writes its results into the
        tensors it is given -- confirm against the op's registration.
    """
    # The op is resolved at call time, so importing this module does not
    # require the sgl_kernel extension to be loaded yet.
    torch.ops.sgl_kernel.causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
|