import torch


def fast_topk(values: torch.Tensor, topk: int, dim: int):
    """Return the ``topk`` largest entries of *values* along *dim*.

    Args:
        values: Input tensor.
        topk: Number of largest entries to select.
        dim: Dimension to reduce over.

    Returns:
        A ``(values, indices)`` named tuple. In both branches the size of
        *dim* in the outputs is ``topk``. For ``topk == 1`` this is achieved
        with ``keepdim=True``, so the result has the same shape ``torch.topk``
        would produce.
    """
    if topk == 1:
        # torch.max is a cheaper reduction than a general top-k selection.
        # keepdim=True preserves the reduced dim (size 1) so the output shape
        # matches torch.topk(values, 1, dim=dim).
        # NOTE: per the PyTorch docs, torch.max returns the index of the first
        # maximal value on ties; torch.topk makes no such guarantee.
        return torch.max(values, dim=dim, keepdim=True)
    # Use topk for larger k values.
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)