Browse Source

Fix topk (#187)

* fix typo

* use std::partial_sort to replace std::sort
PF Luo 2 years ago
parent
commit
baaea55740
1 changed files with 4 additions and 4 deletions
  1. 4 4
      sherpa-ncnn/csrc/math.h

+ 4 - 4
sherpa-ncnn/csrc/math.h

@@ -106,10 +106,10 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
   std::vector<int32_t> vec_index(size);
   std::iota(vec_index.begin(), vec_index.end(), 0);
 
-  std::sort(vec_index.begin(), vec_index.end(),
-            [vec](int32_t index_1, int32_t index_2) {
-              return vec[index_1] > vec[index_2];
-            });
+  std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
+                    vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
+                      return vec[index_1] > vec[index_2];
+                    });
 
   int32_t k_num = std::min<int32_t>(size, topk);
   std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);