123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- '''Train CIFAR10 with PyTorch.'''
- from __future__ import print_function
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torch.nn.functional as F
- import torch.backends.cudnn as cudnn
- import torchvision
- import torchvision.transforms as transforms
- import os
- import argparse
- from senet import *
- from utils import progress_bar
- from torch.autograd import Variable
# Command-line options.  --gpu is an addition; its default (3) is the device
# index that was previously hard-coded, so existing invocations are unchanged.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--gpu', default=3, type=int, help='index of the CUDA device to use')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
if use_cuda:
    # Select the device only when CUDA is usable: the rest of the script
    # branches on use_cuda, so a CPU-only host must not reach set_device().
    if not 0 <= args.gpu < torch.cuda.device_count():
        # Fail with a readable message instead of a low-level CUDA error.
        parser.error('--gpu %d is out of range: %d CUDA device(s) visible'
                     % (args.gpu, torch.cuda.device_count()))
    torch.cuda.set_device(args.gpu)

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Data
print('==> Preparing data..')

# Per-channel mean / std handed to Normalize for both splits.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)
normalize = transforms.Normalize(channel_mean, channel_std)

# Training images are randomly cropped (with 4px padding) and flipped before
# conversion to tensors; test images are only converted and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Both splits live under ./data and are downloaded on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)

# Human-readable label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not os.path.isdir('checkpoint'):
        raise IOError('Error: no checkpoint directory found!')
    # Deserialize onto the CPU first so a checkpoint written on a GPU can also
    # be resumed on a machine without that device; net.cuda() below moves the
    # model to the selected GPU when one is available.
    # NOTE(review): torch.load unpickles the file, and the checkpoint holds the
    # whole pickled module -- only resume from checkpoints you trust.  Recent
    # PyTorch releases may additionally need weights_only=False here; confirm
    # against the installed version.
    checkpoint = torch.load('./checkpoint/ckpt.t7',
                            map_location=lambda storage, loc: storage)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    # The checkpoint is written after epoch checkpoint['epoch'] has been
    # trained and evaluated, so continue with the following epoch.
    start_epoch = checkpoint['epoch'] + 1
else:
    print('==> Building model..')
    # Alternatives that were tried here: VGG('VGG19'), ResNet18(),
    # PreActResNet18(), GoogLeNet(), DenseNet121(), ResNeXt29_2x64d(),
    # MobileNet(), DPN92(), ShuffleNetG2().
    net = SENet18()

if use_cuda:
    net.cuda()
    # Multi-GPU variant, currently disabled:
    # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

# NOTE(review): F.nll_loss expects log-probabilities -- confirm that the model
# from senet ends with log_softmax; otherwise a cross-entropy loss is needed.
criterion = F.nll_loss
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Training
def train(epoch):
    """Run one training pass over ``trainloader`` and report progress.

    Puts ``net`` in training mode, performs one optimiser step per batch and
    updates the progress bar with the running mean loss and accuracy.

    Args:
        epoch: Epoch number, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # float(...sum()) yields a plain Python number whether the loss is a
        # one-element tensor or a 0-dim tensor; indexing with [0] fails on the
        # latter.
        train_loss += float(loss.data.sum())
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # int(...) keeps the counter a Python int even when sum() returns a
        # tensor, so the percentage below is ordinary float arithmetic.
        correct += int(predicted.eq(targets.data).cpu().sum())

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best.

    Puts ``net`` in eval mode, accumulates loss and accuracy over the test
    set, and when the accuracy exceeds the module-level ``best_acc`` writes
    ``./checkpoint/ckpt.t7`` (creating the directory if needed) and updates
    ``best_acc``.

    Args:
        epoch: Epoch number stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # NOTE(review): volatile=True is the pre-0.4 way of disabling autograd;
        # newer PyTorch reportedly ignores it with a warning -- consider
        # torch.no_grad() once older versions need not be supported.
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        # float()/int() give plain Python numbers whether sum() returns a
        # number or a tensor; indexing the loss with [0] fails on 0-dim tensors.
        test_loss += float(loss.data.sum())
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += int(predicted.eq(targets.data).cpu().sum())

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        # The whole module object is stored (not a state_dict) because the
        # resume path reads checkpoint['net'] back as the model itself.
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc
# Guard the training loop so it runs only when the file is executed as a
# script.  The data loaders use worker processes; on platforms that start
# workers by re-importing the main module, an unguarded loop would start
# training again inside every worker.
if __name__ == '__main__':
    # Train and evaluate for 100 epochs, continuing from start_epoch.
    for epoch in range(start_epoch, start_epoch + 100):
        train(epoch)
        test(epoch)
|