main_kuangliu.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. '''Train CIFAR10 with PyTorch.'''
  2. from __future__ import print_function
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. import torch.nn.functional as F
  7. import torch.backends.cudnn as cudnn
  8. import torchvision
  9. import torchvision.transforms as transforms
  10. import os
  11. import argparse
  12. from senet import *
  13. from utils import progress_bar
  14. from torch.autograd import Variable
  15. parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
  16. parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
  17. parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
  18. args = parser.parse_args()
  19. use_cuda = torch.cuda.is_available()
  20. torch.cuda.set_device(3)
  21. best_acc = 0 # best test accuracy
  22. start_epoch = 0 # start from epoch 0 or last checkpoint epoch
  23. # Data
  24. print('==> Preparing data..')
  25. transform_train = transforms.Compose([
  26. transforms.RandomCrop(32, padding=4),
  27. transforms.RandomHorizontalFlip(),
  28. transforms.ToTensor(),
  29. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  30. ])
  31. transform_test = transforms.Compose([
  32. transforms.ToTensor(),
  33. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  34. ])
  35. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
  36. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
  37. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
  38. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
  39. classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  40. # Model
  41. if args.resume:
  42. # Load checkpoint.
  43. print('==> Resuming from checkpoint..')
  44. assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  45. checkpoint = torch.load('./checkpoint/ckpt.t7')
  46. net = checkpoint['net']
  47. best_acc = checkpoint['acc']
  48. start_epoch = checkpoint['epoch']
  49. else:
  50. print('==> Building model..')
  51. # net = VGG('VGG19')
  52. # net = ResNet18()
  53. # net = PreActResNet18()
  54. # net = GoogLeNet()
  55. # net = DenseNet121()
  56. # net = ResNeXt29_2x64d()
  57. # net = MobileNet()
  58. # net = DPN92()
  59. # net = ShuffleNetG2()
  60. net = SENet18()
  61. if use_cuda:
  62. net.cuda()
  63. #net = torch.nn.DataParallel(net, device_ids=(0,3))
  64. #net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
  65. cudnn.benchmark = True
  66. criterion = F.nll_loss
  67. optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
  68. # Training
  69. def train(epoch):
  70. print('\nEpoch: %d' % epoch)
  71. net.train()
  72. train_loss = 0
  73. correct = 0
  74. total = 0
  75. for batch_idx, (inputs, targets) in enumerate(trainloader):
  76. if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()
  77. optimizer.zero_grad()
  78. inputs, targets = Variable(inputs), Variable(targets)
  79. outputs = net(inputs)
  80. loss = criterion(outputs, targets)
  81. loss.backward()
  82. optimizer.step()
  83. train_loss += loss.data[0]
  84. _, predicted = torch.max(outputs.data, 1)
  85. total += targets.size(0)
  86. correct += predicted.eq(targets.data).cpu().sum()
  87. progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
  88. % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
  89. def test(epoch):
  90. global best_acc
  91. net.eval()
  92. test_loss = 0
  93. correct = 0
  94. total = 0
  95. for batch_idx, (inputs, targets) in enumerate(testloader):
  96. if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()
  97. inputs, targets = Variable(inputs, volatile=True), Variable(targets)
  98. outputs = net(inputs)
  99. loss = criterion(outputs, targets)
  100. test_loss += loss.data[0]
  101. _, predicted = torch.max(outputs.data, 1)
  102. total += targets.size(0)
  103. correct += predicted.eq(targets.data).cpu().sum()
  104. progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
  105. % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
  106. # Save checkpoint.
  107. acc = 100.*correct/total
  108. if acc > best_acc:
  109. print('Saving..')
  110. state = {
  111. 'net': net,
  112. 'acc': acc,
  113. 'epoch': epoch,
  114. }
  115. if not os.path.isdir('checkpoint'):
  116. os.mkdir('checkpoint')
  117. torch.save(state, './checkpoint/ckpt.t7')
  118. best_acc = acc
  119. for epoch in range(start_epoch, start_epoch+100):
  120. train(epoch)
  121. test(epoch)