sglang_v0.5.2/sglang/sgl-kernel/python/sgl_kernel/marlin.py

45 lines
1021 B
Python

import torch
def gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
)
def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def awq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty(
(num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits
)
return output