|
@@ -106,10 +106,10 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
|
|
std::vector<int32_t> vec_index(size);
|
|
|
std::iota(vec_index.begin(), vec_index.end(), 0);
|
|
|
|
|
|
- std::sort(vec_index.begin(), vec_index.end(),
|
|
|
- [vec](int32_t index_1, int32_t index_2) {
|
|
|
- return vec[index_1] > vec[index_2];
|
|
|
- });
|
|
|
+ std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
|
|
|
+ vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
|
|
|
+ return vec[index_1] > vec[index_2];
|
|
|
+ });
|
|
|
|
|
|
int32_t k_num = std::min<int32_t>(size, topk);
|
|
|
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
|