sglang_v0.5.2/sglang/sgl-kernel/python/sgl_kernel/top_k.py

import torch


def fast_topk(values, topk, dim):
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        # TODO: implement faster cuda kernels for large vocab sizes
        return torch.topk(values, topk, dim=dim)
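

# Usage sketch (illustrative, not part of the original kernel file): both
# branches return a (values, indices) pair, so callers can unpack the result
# the same way regardless of k. The tensor shapes below are hypothetical.
if __name__ == "__main__":
    # Batch of 4 rows over a hypothetical vocab of 32000 logits.
    logits = torch.randn(4, 32000)

    # k == 1 takes the torch.max path; keepdim=True preserves the reduced
    # dimension, so the shapes match what torch.topk would return for k == 1.
    vals1, idx1 = fast_topk(logits, topk=1, dim=-1)  # shapes: (4, 1), (4, 1)

    # k > 1 takes the torch.topk path.
    vals8, idx8 = fast_topk(logits, topk=8, dim=-1)  # shapes: (4, 8), (4, 8)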