|
@@ -1,248 +0,0 @@
|
|
|
-{
|
|
|
- "cells": [
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "%matplotlib inline\n",
|
|
|
- "%reload_ext autoreload\n",
|
|
|
- "%autoreload 2"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import multiprocessing\n",
|
|
|
- "from torch import autograd\n",
|
|
|
- "from fastai.conv_learner import *\n",
|
|
|
- "from fastai.transforms import TfmType\n",
|
|
|
- "from fasterai.transforms import *\n",
|
|
|
- "from fasterai.images import *\n",
|
|
|
- "from fasterai.dataset import *\n",
|
|
|
- "from fasterai.visualize import *\n",
|
|
|
- "from fasterai.callbacks import *\n",
|
|
|
- "from fasterai.loss import *\n",
|
|
|
- "from fasterai.modules import *\n",
|
|
|
- "from fasterai.training import *\n",
|
|
|
- "from fasterai.generators import *\n",
|
|
|
- "from fastai.torch_imports import *\n",
|
|
|
- "from pathlib import Path\n",
|
|
|
- "from itertools import repeat\n",
|
|
|
- "import tensorboardX\n",
|
|
|
- "torch.cuda.set_device(0)\n",
|
|
|
- "plt.style.use('dark_background')\n",
|
|
|
- "torch.backends.cudnn.benchmark=True\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
|
|
|
- "proj_id = 'bwdefade'\n",
|
|
|
- "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
- "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
|
|
|
- "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
|
|
|
- "c_lr=5e-4\n",
|
|
|
- "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
|
|
|
- "\n",
|
|
|
- "g_lr=c_lr/5\n",
|
|
|
- "g_lrs = np.array([g_lr/100,g_lr/10,g_lr])\n",
|
|
|
- "\n",
|
|
|
- "keep_pcts=[0.25,0.25]\n",
|
|
|
- "gen_freeze_tos=[-1,0]\n",
|
|
|
- "lrs_unfreeze_factor=0.05\n",
|
|
|
- "x_tfms = [RandomLighting(0.5, 0.5)]\n",
|
|
|
- "extra_aug_tfms = [BlackAndWhiteTransform(tfm_y=TfmType.PIXEL)]\n",
|
|
|
- "torch.backends.cudnn.benchmark=True"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {},
|
|
|
- "source": [
|
|
|
- "## Training"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "netG = Unet34(nf_factor=2).cuda()\n",
|
|
|
- "#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')\n",
|
|
|
- "#load_model(netG, gpath)\n",
|
|
|
- "\n",
|
|
|
- "netD = DCCritic(ni=3, nf=384).cuda()\n",
|
|
|
- "#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')\n",
|
|
|
- "#load_model(netD, dpath)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "trainer = GANTrainer(netD=netD, netG=netG, genloss_fns=[FeatureLoss(multiplier=1e2)])\n",
|
|
|
- "trainerVis = GANVisualizationHook(TENSORBOARD_PATH, trainer, 'trainer', jupyter=False, visual_iters=100)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "scheds=[]\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[128,128], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=[1.0,1.0], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[64,64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
- "\n",
|
|
|
- "c_lrs=c_lrs/2\n",
|
|
|
- "g_lrs=g_lrs/2\n",
|
|
|
- "\n",
|
|
|
- "#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
- "\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "trainer.train(scheds=scheds)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- }
|
|
|
- ],
|
|
|
- "metadata": {
|
|
|
- "kernelspec": {
|
|
|
- "display_name": "Python 3",
|
|
|
- "language": "python",
|
|
|
- "name": "python3"
|
|
|
- },
|
|
|
- "language_info": {
|
|
|
- "codemirror_mode": {
|
|
|
- "name": "ipython",
|
|
|
- "version": 3
|
|
|
- },
|
|
|
- "file_extension": ".py",
|
|
|
- "mimetype": "text/x-python",
|
|
|
- "name": "python",
|
|
|
- "nbconvert_exporter": "python",
|
|
|
- "pygments_lexer": "ipython3",
|
|
|
- "version": "3.6.5"
|
|
|
- },
|
|
|
- "toc": {
|
|
|
- "colors": {
|
|
|
- "hover_highlight": "#DAA520",
|
|
|
- "navigate_num": "#000000",
|
|
|
- "navigate_text": "#333333",
|
|
|
- "running_highlight": "#FF0000",
|
|
|
- "selected_highlight": "#FFD700",
|
|
|
- "sidebar_border": "#EEEEEE",
|
|
|
- "wrapper_background": "#FFFFFF"
|
|
|
- },
|
|
|
- "moveMenuLeft": true,
|
|
|
- "nav_menu": {
|
|
|
- "height": "67px",
|
|
|
- "width": "252px"
|
|
|
- },
|
|
|
- "navigate_menu": true,
|
|
|
- "number_sections": true,
|
|
|
- "sideBar": true,
|
|
|
- "threshold": 4,
|
|
|
- "toc_cell": false,
|
|
|
- "toc_section_display": "block",
|
|
|
- "toc_window_display": false,
|
|
|
- "widenNotebook": false
|
|
|
- }
|
|
|
- },
|
|
|
- "nbformat": 4,
|
|
|
- "nbformat_minor": 2
|
|
|
-}
|