utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os, sys, time
  2. import random
  3. import torch
  4. import numpy as np
  5. import matplotlib
  6. matplotlib.use('agg')
  7. import matplotlib.pyplot as plt
  8. class AverageMeter(object):
  9. """Computes and stores the average and current value"""
  10. def __init__(self):
  11. self.reset()
  12. def reset(self):
  13. self.val = 0
  14. self.avg = 0
  15. self.sum = 0
  16. self.count = 0
  17. def update(self, val, n=1):
  18. self.val = val
  19. self.sum += val * n
  20. self.count += n
  21. self.avg = self.sum / self.count
  22. class RecorderMeter(object):
  23. """Computes and stores the minimum loss value and its epoch index"""
  24. def __init__(self, total_epoch):
  25. self.reset(total_epoch)
  26. def reset(self, total_epoch):
  27. assert total_epoch > 0
  28. self.total_epoch = total_epoch
  29. self.current_epoch = 0
  30. self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
  31. self.epoch_losses = self.epoch_losses - 1
  32. self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
  33. self.epoch_accuracy= self.epoch_accuracy
  34. def update(self, idx, train_loss, train_acc, val_loss, val_acc):
  35. assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
  36. self.epoch_losses [idx, 0] = train_loss
  37. self.epoch_losses [idx, 1] = val_loss
  38. self.epoch_accuracy[idx, 0] = train_acc
  39. self.epoch_accuracy[idx, 1] = val_acc
  40. self.current_epoch = idx + 1
  41. return self.max_accuracy(False) == val_acc
  42. def max_accuracy(self, istrain):
  43. if self.current_epoch <= 0: return 0
  44. if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
  45. else: return self.epoch_accuracy[:self.current_epoch, 1].max()
  46. def plot_curve(self, save_path):
  47. title = 'the accuracy/loss curve of train/val'
  48. dpi = 80
  49. width, height = 1200, 800
  50. legend_fontsize = 10
  51. scale_distance = 48.8
  52. figsize = width / float(dpi), height / float(dpi)
  53. fig = plt.figure(figsize=figsize)
  54. x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
  55. y_axis = np.zeros(self.total_epoch)
  56. plt.xlim(0, self.total_epoch)
  57. plt.ylim(0, 100)
  58. interval_y = 5
  59. interval_x = 5
  60. plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
  61. plt.yticks(np.arange(0, 100 + interval_y, interval_y))
  62. plt.grid()
  63. plt.title(title, fontsize=20)
  64. plt.xlabel('the training epoch', fontsize=16)
  65. plt.ylabel('accuracy', fontsize=16)
  66. y_axis[:] = self.epoch_accuracy[:, 0]
  67. plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
  68. plt.legend(loc=4, fontsize=legend_fontsize)
  69. y_axis[:] = self.epoch_accuracy[:, 1]
  70. plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
  71. plt.legend(loc=4, fontsize=legend_fontsize)
  72. y_axis[:] = self.epoch_losses[:, 0]
  73. plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
  74. plt.legend(loc=4, fontsize=legend_fontsize)
  75. y_axis[:] = self.epoch_losses[:, 1]
  76. plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
  77. plt.legend(loc=4, fontsize=legend_fontsize)
  78. if save_path is not None:
  79. fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
  80. print ('---- save figure {} into {}'.format(title, save_path))
  81. plt.close(fig)
  82. def time_string():
  83. ISOTIMEFORMAT='%Y-%m-%d %X'
  84. string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
  85. return string
  86. def convert_secs2time(epoch_time):
  87. need_hour = int(epoch_time / 3600)
  88. need_mins = int((epoch_time - 3600*need_hour) / 60)
  89. need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
  90. return need_hour, need_mins, need_secs
  91. def time_file_str():
  92. ISOTIMEFORMAT='%Y-%m-%d'
  93. string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
  94. return string + '-{}'.format(random.randint(1, 10000))