'''Train CIFAR10 with PyTorch.''' from __future__ import print_function import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torch.backends.cudnn as cudnn import torchvision import torchvision.transforms as transforms import os import argparse from senet import * from utils import progress_bar from torch.autograd import Variable parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') args = parser.parse_args() use_cuda = torch.cuda.is_available() torch.cuda.set_device(3) best_acc = 0 # best test accuracy start_epoch = 0 # start from epoch 0 or last checkpoint epoch # Data print('==> Preparing data..') transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # Model if args.resume: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load('./checkpoint/ckpt.t7') net = checkpoint['net'] best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] else: print('==> Building model..') # net = VGG('VGG19') # net = ResNet18() # net = PreActResNet18() # net = GoogLeNet() # net = DenseNet121() # net = ResNeXt29_2x64d() # net = MobileNet() # net = DPN92() # net = ShuffleNetG2() net = SENet18() if use_cuda: net.cuda() #net = torch.nn.DataParallel(net, device_ids=(0,3)) #net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) cudnn.benchmark = True criterion = F.nll_loss optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) # Training def train(epoch): print('\nEpoch: %d' % epoch) net.train() train_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(trainloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() optimizer.zero_grad() inputs, targets = Variable(inputs), Variable(targets) outputs = net(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) def test(epoch): global best_acc net.eval() test_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(testloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() inputs, targets = Variable(inputs, volatile=True), Variable(targets) outputs = net(inputs) loss = criterion(outputs, targets) test_loss += loss.data[0] _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) # Save checkpoint. acc = 100.*correct/total if acc > best_acc: print('Saving..') state = { 'net': net, 'acc': acc, 'epoch': epoch, } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt.t7') best_acc = acc for epoch in range(start_epoch, start_epoch+100): train(epoch) test(epoch)