main_dxy.py

from __future__ import division
from senet import *
import os, sys, shutil, time, random
import argparse
import torch
import torch.nn.functional as F   # needed below for criterion = F.nll_loss
import torch.utils.data           # DataLoader is referenced via torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time
parser = argparse.ArgumentParser(description='Trains ResNeXt on CIFAR or ImageNet', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_path', default='./data', type=str, help='Path to dataset')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10'], help='Choose between CIFAR-10/100 and ImageNet.')
# Optimization options
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--learning_rate', type=float, default=0.05, help='The learning rate.')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=0.0005, help='Weight decay (L2 penalty).')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1], help='LR is multiplied by gamma at each scheduled epoch; the number of gammas must equal the number of schedule milestones.')
# Checkpoints
parser.add_argument('--print_freq', default=200, type=int, metavar='N', help='print frequency (default: 200)')
parser.add_argument('--save_path', type=str, default='./', help='Folder to save checkpoints and log.')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
# Acceleration
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
# Random seed
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()
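
# Example invocation (flag values are illustrative, not prescribed by the repo):
#   python main_dxy.py --data_path ./data --dataset cifar10 \
#       --epochs 300 --batch_size 64 --learning_rate 0.05 \
#       --schedule 150 225 --gammas 0.1 0.1 --workers 2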
args.use_cuda = torch.cuda.is_available()
if args.use_cuda:
    torch.cuda.set_device(0)  # only select a GPU when one is actually available
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if args.use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
cudnn.benchmark = True

def main():
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    # Per-channel statistics (raw pixel means/stds in [0, 255], rescaled to [0, 1]).
    # NOTE: only CIFAR statistics are defined here, so the svhn/stl10 loader branches
    # below are unreachable until their statistics are added.
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])
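    # Normalize maps each channel to (x - mean_c) / std_c, so e.g. the CIFAR-10
    # red channel uses mean 125.3 / 255 ~= 0.491 and std 63.0 / 255 ~= 0.247.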
    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
        test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
        test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet support is not implemented yet'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)
    # Init model, criterion, and optimizer
    #net = models.__dict__[args.arch](num_classes).cuda()
    net = SENet34()
    if args.use_cuda:
        net.cuda()  # move the model to the GPU before building the optimizer

    # define loss function (criterion) and optimizer
    # NOTE: F.nll_loss expects log-probabilities, i.e. it assumes SENet34 ends with
    # log_softmax; if the model returns raw logits, use torch.nn.CrossEntropyLoss() instead.
    criterion = F.nll_loss
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> not using any checkpoint for the model", log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return
    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate)
                  + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()

# train function (forward, backward, update)
def train(train_loader, model, criterion, optimizer, epoch, log):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.use_cuda:
            target = target.cuda(non_blocking=True)
            input = input.cuda()
        # Variable wrappers are unnecessary in PyTorch >= 0.4 (implied by the
        # non_blocking= keyword above); tensors are used directly.

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))  # .item() replaces the 0.3-era loss.data[0]
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:  # honor --print_freq; it was parsed but never used
            print_log('  Epoch: [{:03d}][{:03d}/{:03d}]   '
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})   '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})   '
                      'Loss {loss.val:.4f} ({loss.avg:.4f})   '
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})   '
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})   '.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1, top5=top5) + time_string(), log)
    return top1.avg, losses.avg

def validate(val_loader, model, criterion, log):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()

    # torch.no_grad() replaces the deprecated volatile=True Variable flag
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            if args.use_cuda:
                target = target.cuda(non_blocking=True)
                input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

    print_log('  **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, losses.avg

def print_log(print_string, log):
    print("{}".format(print_string))
    log.write('{}\n'.format(print_string))
    log.flush()

def save_checkpoint(state, is_best, save_path, filename):
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)
    if is_best:
        bestname = os.path.join(save_path, 'model_best.pth.tar')
        shutil.copyfile(filename, bestname)

def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    """Multiplies the initial LR by every gamma whose schedule milestone has been reached."""
    lr = args.learning_rate
    assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"
    for (gamma, step) in zip(gammas, schedule):
        if epoch >= step:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
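
# With the defaults (learning_rate=0.05, schedule=[150, 225], gammas=[0.1, 0.1]),
# adjust_learning_rate produces:
#   epochs   0-149 : lr = 0.05
#   epochs 150-224 : lr = 0.05 * 0.1       = 0.005
#   epochs 225-299 : lr = 0.05 * 0.1 * 0.1 = 0.0005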

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # the transposed slice is non-contiguous, so .contiguous() is required before .view()
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
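
# Quick sanity check for accuracy() (illustrative tensors, not part of the training flow):
#   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # 2 samples, 2 classes
#   target = torch.tensor([1, 1])
#   accuracy(output, target, topk=(1,))  # -> [tensor(50.)] : only the first prediction matches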

if __name__ == '__main__':
    main()