utils_kuangliu.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. '''Some helper functions for PyTorch, including:
  2. - get_mean_and_std: calculate the mean and std value of dataset.
  3. - init_params: net parameter initialization.
  4. - progress_bar: progress bar mimicking xlua.progress.
  5. '''
  6. import os
  7. import sys
  8. import time
  9. import math
  10. import torch.nn as nn
  11. import torch.nn.init as init
  12. def get_mean_and_std(dataset):
  13. '''Compute the mean and std value of dataset.'''
  14. dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  15. mean = torch.zeros(3)
  16. std = torch.zeros(3)
  17. print('==> Computing mean and std..')
  18. for inputs, targets in dataloader:
  19. for i in range(3):
  20. mean[i] += inputs[:,i,:,:].mean()
  21. std[i] += inputs[:,i,:,:].std()
  22. mean.div_(len(dataset))
  23. std.div_(len(dataset))
  24. return mean, std
  25. def init_params(net):
  26. '''Init layer parameters.'''
  27. for m in net.modules():
  28. if isinstance(m, nn.Conv2d):
  29. init.kaiming_normal(m.weight, mode='fan_out')
  30. if m.bias:
  31. init.constant(m.bias, 0)
  32. elif isinstance(m, nn.BatchNorm2d):
  33. init.constant(m.weight, 1)
  34. init.constant(m.bias, 0)
  35. elif isinstance(m, nn.Linear):
  36. init.normal(m.weight, std=1e-3)
  37. if m.bias:
  38. init.constant(m.bias, 0)
  39. _, term_width = os.popen('stty size', 'r').read().split()
  40. term_width = int(term_width)
  41. TOTAL_BAR_LENGTH = 65.
  42. last_time = time.time()
  43. begin_time = last_time
  44. def progress_bar(current, total, msg=None):
  45. global last_time, begin_time
  46. if current == 0:
  47. begin_time = time.time() # Reset for new bar.
  48. cur_len = int(TOTAL_BAR_LENGTH*current/total)
  49. rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
  50. sys.stdout.write(' [')
  51. for i in range(cur_len):
  52. sys.stdout.write('=')
  53. sys.stdout.write('>')
  54. for i in range(rest_len):
  55. sys.stdout.write('.')
  56. sys.stdout.write(']')
  57. cur_time = time.time()
  58. step_time = cur_time - last_time
  59. last_time = cur_time
  60. tot_time = cur_time - begin_time
  61. L = []
  62. L.append(' Step: %s' % format_time(step_time))
  63. L.append(' | Tot: %s' % format_time(tot_time))
  64. if msg:
  65. L.append(' | ' + msg)
  66. msg = ''.join(L)
  67. sys.stdout.write(msg)
  68. for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
  69. sys.stdout.write(' ')
  70. # Go back to the center of the bar.
  71. for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
  72. sys.stdout.write('\b')
  73. sys.stdout.write(' %d/%d ' % (current+1, total))
  74. if current < total-1:
  75. sys.stdout.write('\r')
  76. else:
  77. sys.stdout.write('\n')
  78. sys.stdout.flush()
  79. def format_time(seconds):
  80. days = int(seconds / 3600/24)
  81. seconds = seconds - days*3600*24
  82. hours = int(seconds / 3600)
  83. seconds = seconds - hours*3600
  84. minutes = int(seconds / 60)
  85. seconds = seconds - minutes*60
  86. secondsf = int(seconds)
  87. seconds = seconds - secondsf
  88. millis = int(seconds*1000)
  89. f = ''
  90. i = 1
  91. if days > 0:
  92. f += str(days) + 'D'
  93. i += 1
  94. if hours > 0 and i <= 2:
  95. f += str(hours) + 'h'
  96. i += 1
  97. if minutes > 0 and i <= 2:
  98. f += str(minutes) + 'm'
  99. i += 1
  100. if secondsf > 0 and i <= 2:
  101. f += str(secondsf) + 's'
  102. i += 1
  103. if millis > 0 and i <= 2:
  104. f += str(millis) + 'ms'
  105. i += 1
  106. if f == '':
  107. f = '0ms'
  108. return f