{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Colorize GAN training\n",
    "\n",
    "Trains the `Unet34` generator against the `DCCritic` critic with `GANTrainer` on the ImageNet training images. Inputs go through `BlackAndWhiteTransform`, and the schedule steps the image size up from 64 to 256 px."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing\n",
    "from torch import autograd\n",
    "from fastai.conv_learner import *\n",
    "from fastai.transforms import TfmType\n",
    "from fasterai.transforms import *\n",
    "from fasterai.images import *\n",
    "from fasterai.dataset import *\n",
    "from fasterai.visualize import *\n",
    "from fasterai.callbacks import *\n",
    "from fasterai.loss import *\n",
    "from fasterai.modules import *\n",
    "from fasterai.training import *\n",
    "from fasterai.generators import *\n",
    "from fastai.torch_imports import *\n",
    "from pathlib import Path\n",
    "from itertools import repeat\n",
    "import tensorboardX\n",
    "torch.cuda.set_device(0)\n",
    "plt.style.use('dark_background')\n",
    "torch.backends.cudnn.benchmark=True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
    "proj_id = 'colorize'\n",
    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
    "\n",
    "# Checkpoint paths; only referenced by the commented-out load_model calls in the Training section.\n",
    "gpath = IMAGENET.parent/(proj_id + '_gen_128.h5')\n",
    "dpath = IMAGENET.parent/(proj_id + '_critic_128.h5')\n",
    "\n",
    "# Base learning rates. The schedule cell derives every stage's rates from these\n",
    "# arrays without rebinding them.\n",
    "c_lr=5e-4\n",
    "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
    "\n",
    "g_lr=c_lr/5\n",
    "g_lrs = np.array([g_lr/100,g_lr/10,g_lr])\n",
    "\n",
    "# Per-pass settings for the stages that request two passes in one generate_schedules call.\n",
    "keep_pcts=[0.25,0.25]\n",
    "gen_freeze_tos=[-1,0]\n",
    "lrs_unfreeze_factor=0.05\n",
    "x_tfms = [BlackAndWhiteTransform()]\n",
    "extra_aug_tfms = [RandomLighting(0.1, 0.1, tfm_y=TfmType.PIXEL)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "netG = Unet34(nf_factor=2).cuda()\n",
    "#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')\n",
    "#load_model(netG, gpath)\n",
    "\n",
    "netD = DCCritic(ni=3, nf=256).cuda()\n",
    "#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')\n",
    "#load_model(netD, dpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = GANTrainer(netD=netD, netG=netG, genloss_fns=[FeatureLoss(multiplier=1e2)])\n",
    "trainerVis = GANVisualizationHook(TENSORBOARD_PATH, trainer, 'trainer', jupyter=False, visual_iters=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training schedule\n",
    "\n",
    "The first stage runs at 64 px with the base learning rates. Every later stage halves both learning-rate arrays, starts with a short pass at a tenth of the stage's rates on `keep_pcts=[0.1]` (the `#unshock` step), and then runs the stage's two main passes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_schedules(szs, bss, keep_pcts, c_lrs, g_lrs, gen_freeze_tos):\n",
    "    \"\"\"Build the schedules for one stage via GANTrainSchedule.generate_schedules.\n",
    "\n",
    "    Only the arguments that differ between stages are parameters; path, x_tfms,\n",
    "    extra_aug_tfms, save_base_name and lrs_unfreeze_factor always come from the\n",
    "    config cell. Returns whatever generate_schedules returns (the callers pass\n",
    "    it to list.extend).\n",
    "    \"\"\"\n",
    "    return GANTrainSchedule.generate_schedules(\n",
    "        szs=szs, bss=bss, path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,\n",
    "        keep_pcts=keep_pcts, save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs,\n",
    "        lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (image size, batch size, single_call) for every stage after the initial 64px one.\n",
    "# single_call=True  -> both main passes come from one generate_schedules call that\n",
    "#                      uses the config-cell keep_pcts / gen_freeze_tos.\n",
    "# single_call=False -> the two main passes are separate calls with\n",
    "#                      gen_freeze_tos=[-1] and then [0], each with keep_pcts=[0.25].\n",
    "later_stages = [\n",
    "    (96, 64, True),\n",
    "    (128, 32, True),\n",
    "    (160, 20, False),\n",
    "    (192, 12, False),\n",
    "    (224, 8, False),\n",
    "    (256, 6, False),\n",
    "]\n",
    "\n",
    "# Decay local names only: the config-cell c_lrs / g_lrs are never rebound, so\n",
    "# re-running this cell always rebuilds the same schedule.\n",
    "stage_c_lrs, stage_g_lrs = c_lrs, g_lrs\n",
    "\n",
    "scheds = []\n",
    "scheds.extend(stage_schedules(szs=[64,64], bss=[128,128], keep_pcts=[1.0,1.0],\n",
    "                              c_lrs=stage_c_lrs, g_lrs=stage_g_lrs, gen_freeze_tos=gen_freeze_tos))\n",
    "\n",
    "for stage_sz, stage_bs, single_call in later_stages:\n",
    "    # array/2 builds a new array, so the config-cell arrays stay untouched.\n",
    "    stage_c_lrs = stage_c_lrs/2\n",
    "    stage_g_lrs = stage_g_lrs/2\n",
    "\n",
    "    #unshock\n",
    "    scheds.extend(stage_schedules(szs=[stage_sz], bss=[stage_bs], keep_pcts=[0.1],\n",
    "                                  c_lrs=stage_c_lrs/10, g_lrs=stage_g_lrs/10, gen_freeze_tos=[-1]))\n",
    "\n",
    "    if single_call:\n",
    "        scheds.extend(stage_schedules(szs=[stage_sz,stage_sz], bss=[stage_bs,stage_bs], keep_pcts=keep_pcts,\n",
    "                                      c_lrs=stage_c_lrs, g_lrs=stage_g_lrs, gen_freeze_tos=gen_freeze_tos))\n",
    "    else:\n",
    "        for freeze_to in (-1, 0):\n",
    "            scheds.extend(stage_schedules(szs=[stage_sz], bss=[stage_bs], keep_pcts=[0.25],\n",
    "                                          c_lrs=stage_c_lrs, g_lrs=stage_g_lrs, gen_freeze_tos=[freeze_to]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train(scheds=scheds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "navigate_num": "#000000",
    "navigate_text": "#333333",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700",
    "sidebar_border": "#EEEEEE",
    "wrapper_background": "#FFFFFF"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "67px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false,
   "widenNotebook": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}