12 lines
381 B
Python
12 lines
381 B
Python
import torch
|
|
|
|
|
|
def fast_topk(values, topk, dim):
    """Select the ``topk`` largest entries of ``values`` along ``dim``.

    Dispatches to ``torch.max`` when only a single entry is requested and
    to ``torch.topk`` otherwise.

    Args:
        values: Input tensor to select from.
        topk: Number of entries to select along ``dim``.
        dim: Dimension along which to select.

    Returns:
        The ``(values, indices)`` result of ``torch.max(..., keepdim=True)``
        when ``topk == 1``, else the ``(values, indices)`` result of
        ``torch.topk``.
    """
    if topk == 1:
        # A single max-reduction yields both the value and its index;
        # keepdim=True keeps the reduced dimension (size 1) so the output
        # rank matches the general torch.topk path below.
        return torch.max(values, dim=dim, keepdim=True)

    # General k: defer to torch.topk.
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)