{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pretrained GAN\n",
    "\n",
    "Colorization with fastai v1 and fasterai: pretrain a generator (with `FeatureLoss`) and a critic separately, then fine-tune them together as a GAN. The critic-retraining and GAN steps can be repeated from saved checkpoints (see *Repeat Pretrain-GAN Cycle*)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 0; this must be set before torch/fastai initialise CUDA.\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import shutil\n",
    "\n",
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks.tensorboard import *\n",
    "from fastai.vision.gan import *\n",
    "from fasterai.generators import *\n",
    "from fasterai.critics import *\n",
    "from fasterai.dataset import *\n",
    "from fasterai.loss import *\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from PIL import ImageFile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
    "path_hr = path                 # color targets\n",
    "path_lr = path/'bandw'         # grayscale inputs written by crappify()\n",
    "\n",
    "proj_id = 'ColorizeNew73'\n",
    "gen_name = proj_id + '_gen'\n",
    "crit_name = proj_id + '_crit'\n",
    "\n",
    "name_gen = proj_id + '_image_gen'\n",
    "path_gen = path/name_gen       # generator outputs; the critic's 'generated' class\n",
    "\n",
    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
    "\n",
    "nf_factor = 2                  # passed to gen_learner_wide as nf_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_all(suffix=''):\n",
    "    \"\"\"Save the global `learn_gen` and `learn_crit` as `<name><sz><suffix>` (uses the global `sz`).\"\"\"\n",
    "    learn_gen.save(gen_name + str(sz) + suffix)\n",
    "    learn_crit.save(crit_name + str(sz) + suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_all(suffix=''):\n",
    "    \"\"\"Load weights written by `save_all` with the same `suffix` (global `sz`); optimizer state is not restored.\"\"\"\n",
    "    learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
    "    learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(bs:int, sz:int, keep_pct:float):\n",
    "    \"\"\"Grayscale -> color DataBunch: inputs from `path_lr`, targets from `path_hr`.\n",
    "\n",
    "    `keep_pct` is forwarded to `get_colorize_data` (presumably the fraction of images kept;\n",
    "    with `random_seed=None` the kept subset is not fixed between calls - confirm in fasterai).\n",
    "    \"\"\"\n",
    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr,\n",
    "                             random_seed=None, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_crit_data(classes, bs, sz):\n",
    "    \"\"\"Folder-labelled classification DataBunch for the critic.\n",
    "\n",
    "    Images are taken from the `classes` sub-folders of `path` (via `include=`), split 90/10\n",
    "    train/valid with seed 42, augmented with `get_transforms(max_zoom=2.)` and normalized\n",
    "    with ImageNet statistics.\n",
    "    \"\"\"\n",
    "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
    "    ll = src.label_from_folder(classes=classes)\n",
    "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
    "            .databunch(bs=bs).normalize(imagenet_stats))\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def crappify(fn,i):\n",
    "    \"\"\"Write a grayscale copy of `fn` to the mirrored location under `path_lr`.\n",
    "\n",
    "    `i` is the item index supplied by fastai's `parallel`; it is unused.\n",
    "    \"\"\"\n",
    "    dest = path_lr/fn.relative_to(path_hr)\n",
    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "    # 'LA' drops the color; converting back to 'RGB' keeps the inputs 3-channel.\n",
    "    with PIL.Image.open(fn) as src:\n",
    "        img = src.convert('LA').convert('RGB')\n",
    "    img.save(dest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_preds(dl, learn=None):\n",
    "    \"\"\"Save reconstructed generator predictions for every batch of `dl` into `path_gen`.\n",
    "\n",
    "    Outputs reuse the file names of `dl.dataset.items` in order, so `dl` must iterate\n",
    "    in dataset order (e.g. `data.fix_dl`). Uses the global `learn_gen` unless `learn` is given.\n",
    "    \"\"\"\n",
    "    if learn is None: learn = learn_gen\n",
    "    names = iter(dl.dataset.items)\n",
    "    for b in dl:\n",
    "        for o in learn.pred_batch(batch=b, reconstruct=True):\n",
    "            o.save(path_gen/next(names).name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_gen_images(learn_gen, keep_pct=0.085):\n",
    "    \"\"\"Rebuild `path_gen` from generator predictions and return the first image as a preview.\n",
    "\n",
    "    `learn_gen` is the generator `Learner` to predict with. Any non-`Learner` value (older\n",
    "    cells passed the checkpoint name) falls back to the global `learn_gen`.\n",
    "    The data is built with the global `bs` and `sz`.\n",
    "    \"\"\"\n",
    "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
    "    path_gen.mkdir(parents=True, exist_ok=True)\n",
    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\n",
    "    save_preds(data_gen.fix_dl, learn_gen if isinstance(learn_gen, Learner) else None)\n",
    "    # Returned (not just opened) so the calling cell displays it.\n",
    "    return PIL.Image.open(path_gen.ls()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Crappified data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the model inputs by *crappifying* the color images: `crappify` writes a grayscale copy of each image under `path_lr`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Uncomment and run the next cell only the first time you run this notebook. Note that `path_lr` lives inside `path_hr`, so a recursive listing of `path_hr` on later runs would also pick up the grayscale copies."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#il = ImageList.from_folder(path_hr)\n",
    "#parallel(crappify, il.items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-train generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's pretrain the generator on its own with `FeatureLoss`, progressively increasing the image size: 64px, then 128px, then 192px (on half of the data)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 64px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=88\n",
    "sz=64\n",
    "keep_pct=1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(2, pct_start=0.8, max_lr=slice(1e-3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(gen_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.load(gen_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(2, pct_start=0.01, max_lr=slice(3e-7, 3e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 128px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=20\n",
    "sz=128\n",
    "keep_pct=1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.load(gen_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(2, pct_start=0.01, max_lr=slice(1e-7,1e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(gen_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.load(gen_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 192px (50% of the data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192\n",
    "keep_pct=0.50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=0.01, max_lr=slice(5e-8,5e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save generated images\n",
    "\n",
    "Write generator outputs to `path_gen`. Together with the real `test` images, they form the critic's training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_gen_images(learn_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train critic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pretrain the critic on 
generated images (folder `name_gen`) vs. real images (folder `test`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=64\n",
    "sz=128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release the generator so its memory can be reclaimed before building the critic.\n",
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic = colorize_crit_learner(data=data_crit, nf=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.fit_one_cycle(6, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.save(crit_name + '1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=16\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.fit_one_cycle(4, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrites the 128px checkpoint above; the GAN section loads this name.\n",
    "learn_critic.save(crit_name + '1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GAN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we'll combine the pretrained generator and critic in a GAN."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit=None\n",
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=2e-5\n",
    "sz=192\n",
    "bs=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only needed to construct the critic learner; GAN training uses the data assigned to `learn.data` below.\n",
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '1', with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `data_gen` is defined in the generator pre-training section above.\n",
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each iteration: one epoch on a new keep_pct=0.001 sample of the data, then a checkpoint suffixed '_1_<i>'.\n",
    "for i in range(1,101):\n",
    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
    "    learn_gen.freeze_to(-1)\n",
    "    learn.fit(1,lr)\n",
    "    save_all('_1_' + str(i))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repeat Pretrain-GAN Cycle\n",
    "\n",
    "Regenerate images with a chosen GAN generator checkpoint, retrain the critic on them, then resume GAN training. Set the checkpoint numbers below before running."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_checkpoint_num = 5\n",
    "old_checkpoint_iter = 7  # GAN-loop iteration `i` of the generator checkpoint to resume from\n",
    "checkpoint_num = old_checkpoint_num + 1\n",
    "# Same scheme as save_all: gen_name + size + '_<cycle>_<iteration>' -> 'ColorizeNew73_gen192_5_7'.\n",
    "gen_old_checkpoint_name = gen_name + '192_' + str(old_checkpoint_num) + '_' + str(old_checkpoint_iter)\n",
    "crit_old_checkpoint_name = crit_name + str(old_checkpoint_num)\n",
    "crit_new_checkpoint_name= crit_name + str(checkpoint_num)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Generated Images Again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_gen_images(learn_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Critic Again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=16\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.fit_one_cycle(4, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
"outputs": [],
   "source": [
    "learn_critic.save(crit_new_checkpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN Again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit=None\n",
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=2e-5\n",
    "sz=192\n",
    "bs=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only needed to construct the critic learner; GAN training uses the data assigned to `learn.data` below.\n",
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `data_gen` is defined in the generator pre-training section above.\n",
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1,101):\n",
    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
    "    learn_gen.freeze_to(-1)\n",
    "    learn.fit(1,lr)\n",
    "    # Suffix '_<checkpoint_num>_<i>', the scheme gen_old_checkpoint_name is built from.\n",
    "    save_all('_' + str(checkpoint_num) + '_' + str(i))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fin"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}