math.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. /**
  2. * Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey)
  3. * Copyright (c) 2022 (Pingfeng Luo)
  4. *
  5. * See LICENSE for clarification regarding multiple authors
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. // This file is copied from k2/csrc/utils.h
  20. #ifndef SHERPA_NCNN_CSRC_MATH_H_
  21. #define SHERPA_NCNN_CSRC_MATH_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>
  27. namespace sherpa_ncnn {
  28. // logf(FLT_EPSILON)
  29. #define SHERPA_NCNN_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f
  30. // log(DBL_EPSILON)
  31. #define SHERPA_NCNN_MIN_LOG_DIFF_DOUBLE \
  32. -36.0436533891171535515240975655615329742431640625
  33. template <typename T>
  34. struct LogAdd;
  35. template <>
  36. struct LogAdd<double> {
  37. double operator()(double x, double y) const {
  38. double diff;
  39. if (x < y) {
  40. diff = x - y;
  41. x = y;
  42. } else {
  43. diff = y - x;
  44. }
  45. // diff is negative. x is now the larger one.
  46. if (diff >= SHERPA_NCNN_MIN_LOG_DIFF_DOUBLE) {
  47. double res;
  48. res = x + log1p(exp(diff));
  49. return res;
  50. }
  51. return x; // return the larger one.
  52. }
  53. };
  54. template <>
  55. struct LogAdd<float> {
  56. float operator()(float x, float y) const {
  57. float diff;
  58. if (x < y) {
  59. diff = x - y;
  60. x = y;
  61. } else {
  62. diff = y - x;
  63. }
  64. // diff is negative. x is now the larger one.
  65. if (diff >= SHERPA_NCNN_MIN_LOG_DIFF_DOUBLE) {
  66. float res;
  67. res = x + log1pf(expf(diff));
  68. return res;
  69. }
  70. return x; // return the larger one.
  71. }
  72. };
  73. template <class T>
  74. void LogSoftmax(T *input, int32_t input_len) {
  75. assert(input);
  76. T m = *std::max_element(input, input + input_len);
  77. T sum = 0.0;
  78. for (int32_t i = 0; i < input_len; i++) {
  79. sum += exp(input[i] - m);
  80. }
  81. T offset = m + log(sum);
  82. for (int32_t i = 0; i < input_len; i++) {
  83. input[i] -= offset;
  84. }
  85. }
  86. template <class T>
  87. std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
  88. std::vector<int32_t> vec_index(size);
  89. std::iota(vec_index.begin(), vec_index.end(), 0);
  90. std::sort(vec_index.begin(), vec_index.end(),
  91. [vec](int32_t index_1, int32_t index_2) {
  92. return vec[index_1] > vec[index_2];
  93. });
  94. int32_t k_num = std::min<int32_t>(size, topk);
  95. std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
  96. return index;
  97. }
  98. } // namespace sherpa_ncnn
  99. #endif // SHERPA_NCNN_CSRC_MATH_H_