metrics.py

from .imports import *
from .torch_imports import *

def accuracy_np(preds, targs):
    # Top-1 accuracy for numpy arrays: argmax over the class axis vs. integer targets.
    preds = np.argmax(preds, 1)
    return (preds == targs).mean()

def accuracy(preds, targs):
    # Top-1 accuracy for torch tensors; torch.max returns (values, indices).
    preds = torch.max(preds, dim=1)[1]
    return (preds == targs).float().mean()

def accuracy_thresh(thresh):
    # Returns a metric closure with the threshold baked in.
    return lambda preds, targs: accuracy_multi(preds, targs, thresh)

def accuracy_multi(preds, targs, thresh):
    # Element-wise accuracy for multi-label predictions against binary targets.
    return ((preds > thresh).float() == targs).float().mean()

def accuracy_multi_np(preds, targs, thresh):
    return ((preds > thresh) == targs).mean()
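
# Usage sketch (not from the original file; hypothetical values, and it
# assumes torch is importable directly rather than via .torch_imports):
#
#   preds = torch.tensor([[0.1, 0.2, 0.7],    # two samples, three classes
#                         [0.8, 0.1, 0.1]])
#   targs = torch.tensor([2, 0])              # true classes
#   accuracy(preds, targs)                    # -> tensor(1.)
#
#   ml_preds = torch.tensor([[0.9, 0.2], [0.4, 0.8]])
#   ml_targs = torch.tensor([[1., 0.], [0., 1.]])
#   accuracy_multi(ml_preds, ml_targs, 0.5)   # -> tensor(1.)
#   accuracy_thresh(0.5)(ml_preds, ml_targs)  # same metric, threshold baked in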
def recall(preds, targs, thresh=0.5):
    pred_pos = preds > thresh
    tpos = torch.mul((targs.byte() == pred_pos), targs.byte())
    # Cast to float before dividing: integer division on the raw sums would truncate.
    return tpos.sum().float() / targs.sum().float()

def precision(preds, targs, thresh=0.5):
    pred_pos = preds > thresh
    tpos = torch.mul((targs.byte() == pred_pos), targs.byte())
    return tpos.sum().float() / pred_pos.sum().float()
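
# recall and precision above are micro-averaged: true positives are counted
# globally across all labels at once. A hedged sketch with made-up values:
#
#   preds = torch.tensor([0.9, 0.3, 0.7])   # one tp, one fn, one fp at 0.5
#   targs = torch.tensor([1., 1., 0.])
#   recall(preds, targs)                     # tp=1 / positives=2 -> 0.5
#   precision(preds, targs)                  # tp=1 / predicted=2 -> 0.5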
def fbeta(preds, targs, beta, thresh=0.5):
    """Calculates the F-beta score (the weighted harmonic mean of precision and recall).

    This is the micro-averaged version, where the true positives, false negatives and
    false positives are calculated globally (as opposed to on a per-label basis).

    beta == 1 places equal weight on precision and recall, beta < 1 emphasizes
    precision, and beta > 1 favors recall.
    """
    assert beta > 0, 'beta needs to be greater than 0'
    beta2 = beta ** 2
    rec = recall(preds, targs, thresh)
    prec = precision(preds, targs, thresh)
    return (1 + beta2) * prec * rec / (beta2 * prec + rec)

def f1(preds, targs, thresh=0.5): return fbeta(preds, targs, 1, thresh)
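
# The fbeta return line implements the standard definition
#   F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall)
# Note that if there are no predicted positives, precision divides by zero and
# the result can be nan. Sketch with the hypothetical values from above, where
# precision == recall == 0.5, so every F-beta collapses to 0.5:
#
#   preds = torch.tensor([0.9, 0.3, 0.7])
#   targs = torch.tensor([1., 1., 0.])
#   f1(preds, targs)                         # -> tensor(0.5000)
#   fbeta(preds, targs, 2)                   # beta > 1 favors recall; still 0.5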