losses.py 415 B

123456789101112
  1. from .imports import *
  2. from .torch_imports import *
  def fbeta_torch(y_true, y_pred, beta, threshold, eps=1e-9, *, dim=1):
      """Return the mean F-beta score of thresholded predictions.

      Predictions are binarised with ``y_pred > threshold``. Precision,
      recall and F-beta are then computed independently for each slice
      along ``dim`` (each row of a 2-D input with the default ``dim=1``),
      and the per-slice scores are averaged.

      Args:
          y_true: Ground-truth tensor, expected to hold 0/1 labels; cast to
              float before use.
          y_pred: Scores/probabilities with a shape broadcast-compatible with
              ``y_true``.
          beta: Weight of recall relative to precision (``beta=1`` is F1).
          threshold: Scores strictly greater than this count as positive.
          eps: Small constant added to every denominator to avoid division
              by zero.
          dim: Axis along which true positives, predicted positives and
              actual positives are counted. Defaults to 1 (the previously
              hard-coded axis); pass ``dim=0`` or ``dim=-1`` to score a
              single 1-D sample.

      Returns:
          A 0-d tensor: the F-beta score averaged over all remaining entries.

      Note:
          Because of the hard ``>`` threshold, no gradient flows back to
          ``y_pred``; this is a metric, not a trainable loss. A slice with
          no predicted and no true positives scores 0 (P = R = 0).
      """
      y_pred = (y_pred.float() > threshold).float()
      y_true = y_true.float()
      beta2 = beta ** 2

      tp = (y_pred * y_true).sum(dim=dim)
      precision = tp / (y_pred.sum(dim=dim) + eps)
      recall = tp / (y_true.sum(dim=dim) + eps)

      # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); eps keeps the
      # denominator non-zero when both precision and recall are 0.
      fbeta = precision * recall / (precision * beta2 + recall + eps) * (1 + beta2)
      return torch.mean(fbeta)