{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Colorization GAN training\n",
    "\n",
    "Trains a self-attention U-Net generator (ResNet-34 encoder) and a critic as a GAN that maps decolorized ImageNet images (`BWIMAGENET`) back to the colour originals (`IMAGENET`).\n",
    "\n",
    "Training uses progressive resizing from 64px to 256px. From 96px on, each stage first fits briefly on a 10% data subset at `lr/10`. Every stage then fits with all but the last generator layer group frozen (`freeze_to(-1)`), and finally fits the unfrozen generator at `lr*unfreeze_fctr`. `save()` checkpoints both learners after every fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Select the GPU before torch/fastai are imported below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks import *\n",
    "from fastai.vision.gan import *\n",
    "from fasterai.dataset import *\n",
    "from fasterai.visualize import *\n",
    "from fasterai.tensorboard import *\n",
    "from fasterai.loss import *\n",
    "from fasterai.critics import *\n",
    "from fasterai.generators import *\n",
    "from pathlib import Path\n",
    "from itertools import repeat\n",
    "from PIL import ImageFile\n",
    "plt.style.use('dark_background')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
    "BWIMAGENET = IMAGENET/'bandw'\n",
    "\n",
    "proj_id = 'colorize1'\n",
    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
    "\n",
    "torch.backends.cudnn.benchmark=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decolorize(fn:Path, i:int):\n",
    "    \"Save a grayscale copy of image `fn` under `BWIMAGENET`, mirroring its path relative to `IMAGENET`.\"\n",
    "    # `i` is the item index supplied by fastai's `parallel`; it is not needed here.\n",
    "    dest = BWIMAGENET/fn.relative_to(IMAGENET)\n",
    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "    # 'LA' drops colour; converting back to 'RGB' keeps 3 channels so inputs match the colour targets.\n",
    "    # The `with` block closes the source file handle once the copy is written.\n",
    "    with PIL.Image.open(fn) as img:\n",
    "        img.convert('LA').convert('RGB').save(dest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next two cells write a grayscale copy of the ImageNet `val` and `train` splits under `BWIMAGENET`. This conversion is slow, so each split is skipped when its output folder already exists. If a conversion run was interrupted, delete that split's folder to regenerate it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not (BWIMAGENET/'val').exists():\n",
    "    il = ImageItemList.from_folder(IMAGENET/'val')\n",
    "    parallel(decolorize, il.items, max_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not (BWIMAGENET/'train').exists():\n",
    "    il = ImageItemList.from_folder(IMAGENET/'train')\n",
    "    parallel(decolorize, il.items, max_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(sz:int, bs:int, keep_pct:float):\n",
    "    \"Build the colorization data (grayscale `BWIMAGENET` inputs, `IMAGENET` targets) at `sz`px with batch size `bs`.\"\n",
    "    # `keep_pct` is forwarded to `get_colorize_data`; below it is used as the fraction of the dataset to train on.\n",
    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=BWIMAGENET, good_path=IMAGENET,\n",
    "                             random_seed=None, keep_pct=keep_pct, num_workers=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save():\n",
    "    \"Save generator and critic weights, tagged with `proj_id` and the current global image size `sz`.\"\n",
    "    learn_gen.save(f'{proj_id}_gen_{sz}')\n",
    "    learn_crit.save(f'{proj_id}_crit_{sz}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load():\n",
    "    \"Load the generator and critic weights saved by `save()` for the current global image size `sz`.\"\n",
    "    learn_gen.load(f'{proj_id}_gen_{sz}')\n",
    "    learn_crit.load(f'{proj_id}_crit_{sz}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=None, arch=models.resnet34):\n",
    "    \"Create the self-attention, spectral-norm U-Net generator learner on `data`, trained with `gen_loss`.\"\n",
    "    # Build the default loss per call. A `FeatureLoss()` default argument would be constructed once,\n",
    "    # when this cell runs, and then shared by every learner created from this function.\n",
    "    if gen_loss is None: gen_loss = FeatureLoss()\n",
    "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
    "                         self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Build the generator and critic learners on 64px data and combine them into a `GANLearner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Needed to instantiate critic but not actually used\n",
    "# NOTE(review): the 64px fits below run before any `learn.data = ...` reassignment, so this data looks like\n",
    "# it is also what they train on - confirm whether the comment above is stale.\n",
    "sz=64\n",
    "bs=128\n",
    "\n",
    "data = get_data(sz=sz, bs=bs, keep_pct=1.0)\n",
    "learn_crit = colorize_crit_learner(data=data, nf=256)\n",
    "learn_crit.unfreeze()\n",
    "\n",
    "gen_loss = FeatureLoss()\n",
    "learn_gen = colorize_gen_learner_exp(data=data, gen_loss=gen_loss)\n",
    "\n",
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
    "\n",
    "lr=1e-4\n",
    "unfreeze_fctr=0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 64px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 96px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sz=96\n",
    "bs=bs//2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 128px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sz=128\n",
    "bs=bs//2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 160px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=lr/1.5\n",
    "sz=160\n",
    "bs=int(bs//1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 192px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=lr/1.5\n",
    "sz=192\n",
    "bs=int(bs//1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 224px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=lr/1.5\n",
    "sz=224\n",
    "bs=int(bs//1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 256px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=lr/1.75\n",
    "sz=256\n",
    "bs=int(bs//1.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()\n",
    "learn.fit(1,lr*unfreeze_fctr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "navigate_num": "#000000",
    "navigate_text": "#333333",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700",
    "sidebar_border": "#EEEEEE",
    "wrapper_background": "#FFFFFF"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "67px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false,
   "widenNotebook": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}