utils_kuangliu.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. '''Some helper functions for PyTorch, including:
  2. - get_mean_and_std: calculate the mean and std value of dataset.
  3. - init_params: net parameter initialization.
  4. - progress_bar: progress bar mimic xlua.progress.
  5. '''
  6. import os
  7. import sys
  8. import time
  9. import math
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.init as init
  13. def get_mean_and_std(dataset):
  14. '''Compute the mean and std value of dataset.'''
  15. dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  16. mean = torch.zeros(3)
  17. std = torch.zeros(3)
  18. print('==> Computing mean and std..')
  19. for inputs, targets in dataloader:
  20. for i in range(3):
  21. mean[i] += inputs[:,i,:,:].mean()
  22. std[i] += inputs[:,i,:,:].std()
  23. mean.div_(len(dataset))
  24. std.div_(len(dataset))
  25. return mean, std
  26. def init_params(net):
  27. '''Init layer parameters.'''
  28. for m in net.modules():
  29. if isinstance(m, nn.Conv2d):
  30. init.kaiming_normal(m.weight, mode='fan_out')
  31. if m.bias:
  32. init.constant(m.bias, 0)
  33. elif isinstance(m, nn.BatchNorm2d):
  34. init.constant(m.weight, 1)
  35. init.constant(m.bias, 0)
  36. elif isinstance(m, nn.Linear):
  37. init.normal(m.weight, std=1e-3)
  38. if m.bias:
  39. init.constant(m.bias, 0)
  40. _, term_width = os.popen('stty size', 'r').read().split()
  41. term_width = int(term_width)
  42. TOTAL_BAR_LENGTH = 65.
  43. last_time = time.time()
  44. begin_time = last_time
  45. def progress_bar(current, total, msg=None):
  46. global last_time, begin_time
  47. if current == 0:
  48. begin_time = time.time() # Reset for new bar.
  49. cur_len = int(TOTAL_BAR_LENGTH*current/total)
  50. rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
  51. sys.stdout.write(' [')
  52. for i in range(cur_len):
  53. sys.stdout.write('=')
  54. sys.stdout.write('>')
  55. for i in range(rest_len):
  56. sys.stdout.write('.')
  57. sys.stdout.write(']')
  58. cur_time = time.time()
  59. step_time = cur_time - last_time
  60. last_time = cur_time
  61. tot_time = cur_time - begin_time
  62. L = []
  63. L.append(' Step: %s' % format_time(step_time))
  64. L.append(' | Tot: %s' % format_time(tot_time))
  65. if msg:
  66. L.append(' | ' + msg)
  67. msg = ''.join(L)
  68. sys.stdout.write(msg)
  69. for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
  70. sys.stdout.write(' ')
  71. # Go back to the center of the bar.
  72. for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
  73. sys.stdout.write('\b')
  74. sys.stdout.write(' %d/%d ' % (current+1, total))
  75. if current < total-1:
  76. sys.stdout.write('\r')
  77. else:
  78. sys.stdout.write('\n')
  79. sys.stdout.flush()
  80. def format_time(seconds):
  81. days = int(seconds / 3600/24)
  82. seconds = seconds - days*3600*24
  83. hours = int(seconds / 3600)
  84. seconds = seconds - hours*3600
  85. minutes = int(seconds / 60)
  86. seconds = seconds - minutes*60
  87. secondsf = int(seconds)
  88. seconds = seconds - secondsf
  89. millis = int(seconds*1000)
  90. f = ''
  91. i = 1
  92. if days > 0:
  93. f += str(days) + 'D'
  94. i += 1
  95. if hours > 0 and i <= 2:
  96. f += str(hours) + 'h'
  97. i += 1
  98. if minutes > 0 and i <= 2:
  99. f += str(minutes) + 'm'
  100. i += 1
  101. if secondsf > 0 and i <= 2:
  102. f += str(secondsf) + 's'
  103. i += 1
  104. if millis > 0 and i <= 2:
  105. f += str(millis) + 'ms'
  106. i += 1
  107. if f == '':
  108. f = '0ms'
  109. return f