"""Python wrappers around the ``sgl_kernel`` Marlin weight-repack ops."""
import torch
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack a GPTQ-quantized weight tensor into the Marlin kernel layout.

    Thin wrapper that delegates to the compiled ``sgl_kernel`` extension op.
    Annotations added to match the sibling ``awq_marlin_*`` wrappers.

    Args:
        b_q_weight: Packed quantized weight tensor (layout defined by the
            kernel; exact packing not visible here — confirm against the
            sgl_kernel op signature).
        perm: Permutation tensor produced during GPTQ quantization.
        size_k: Presumably the input-feature (K) dimension of the unpacked
            weight — TODO confirm against the kernel.
        size_n: Presumably the output-feature (N) dimension — TODO confirm.
        num_bits: Quantization bit width.

    Returns:
        The repacked weight tensor in Marlin format.
    """
    return torch.ops.sgl_kernel.gptq_marlin_repack(
        b_q_weight,
        perm,
        size_k,
        size_n,
        num_bits,
    )
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack an AWQ-quantized weight tensor into the Marlin kernel layout.

    Pure pass-through to the compiled ``sgl_kernel`` extension op; no
    validation or reshaping happens on the Python side.
    """
    repack_op = torch.ops.sgl_kernel.awq_marlin_repack
    return repack_op(b_q_weight, size_k, size_n, num_bits)
def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack per-expert AWQ-quantized weights into the Marlin kernel layout.

    Runs the single-matrix ``awq_marlin_repack`` op on each expert slice of
    ``b_q_weight`` and stacks the results along the expert dimension.

    Args:
        b_q_weight: Stacked quantized weights; the leading dimension is the
            number of experts.
        perm: Permutation tensor. NOTE(review): not used by this wrapper —
            presumably kept for signature parity with the GPTQ variant.
        size_k: Input-feature (K) dimension; must be a multiple of 16.
        size_n: Output-feature (N) dimension.
        num_bits: Quantization bit width.

    Returns:
        Tensor of shape ``(num_experts, size_k // 16, size_n * (num_bits // 2))``
        with the same device and dtype as ``b_q_weight``.

    Raises:
        ValueError: If ``size_k`` is not a multiple of 16.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if size_k % 16 != 0:
        raise ValueError(f"size_k must be a multiple of 16, got {size_k}")
    num_experts = b_q_weight.shape[0]
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    # Repack each expert's weight slice independently with the 2-D kernel.
    for e in range(num_experts):
        output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
            b_q_weight[e], size_k, size_n, num_bits
        )
    return output