{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Pretrained GAN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES']='2' " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastai\n", "from fastai import *\n", "from fastai.vision import *\n", "from fastai.callbacks import *\n", "from fastai.vision.gan import *\n", "from fasterai.generators import *\n", "from fasterai.tensorboard import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.PETS)\n", "path_hr = path/'images'\n", "path_lr = path/'crappy'\n", "\n", "proj_id = 'SuperResRefine5c'\n", "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Crappified data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare the input data by crappifying images." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image, ImageDraw, ImageFont" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def crappify(fn,i):\n", "    \"\"\"Save a degraded copy of the image `fn` under `path_lr`.\n", "\n", "    The destination mirrors `fn`'s path relative to `path_hr`. The image is resized\n", "    with `resize_to(img, 96, use_min=True)`, converted to RGB, stamped with a random\n", "    number in [10, 70] and saved with that number as the `quality` option.\n", "    `i` is unused (presumably the item index supplied by `parallel` - confirm).\n", "    \"\"\"\n", "    dest = path_lr/fn.relative_to(path_hr)\n", "    dest.parent.mkdir(parents=True, exist_ok=True)\n", "    # The context manager closes the source file once the resized RGB copy exists.\n", "    with PIL.Image.open(fn) as src_img:\n", "        targ_sz = resize_to(src_img, 96, use_min=True)\n", "        img = src_img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')\n", "    w,h = img.size\n", "    q = random.randint(10,70)\n", "    ImageDraw.Draw(img).text((random.randint(0,w//2),random.randint(0,h//2)), str(q), fill=(255,255,255))\n", "    img.save(dest, quality=q)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uncomment the first time you run this notebook." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#il = ImageItemList.from_folder(path_hr)\n", "#parallel(crappify, il.items)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For gradual resizing we can change the commented line here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs,size=32, 128\n", "# bs,size = 24,160\n", "#bs,size = 8,256\n", "arch = models.resnet34" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-train generator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's pretrain the generator." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arch = models.resnet34\n", "src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_data(bs,size):\n", "    \"\"\"Build the generator DataBunch from `src`, labelling each item with the same-named file in `path_hr`.\n", "\n", "    `bs` is the batch size passed to `databunch`; `size` is forwarded to `transform`.\n", "    Transforms are also applied to the targets (`tfm_y=True`) and both inputs and\n", "    targets are normalized with `imagenet_stats` (`do_y=True`).\n", "    \"\"\"\n", "    labelled = src.label_from_func(lambda x: path_hr/x.name)\n", "    augmented = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)\n", "    data = augmented.databunch(bs=bs).normalize(imagenet_stats, do_y=True)\n", "    data.c = 3\n", "    return data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_gen = get_data(bs,size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wd = 1e-3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_range = (-3.,3.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss_gen = FeatureLoss(gram_wgt=5e3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_gen_learner():\n", "    \"\"\"Return a new `unet_learner2` generator built from the module-level `data_gen` and `arch`.\"\"\"\n", "    gen_kwargs = dict(wd=wd, blur=True, norm_type=NormType.Spectral,\n", "                      self_attention=True, y_range=y_range, loss_func=loss_gen)\n", "    return unet_learner2(data_gen, arch, **gen_kwargs)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = create_gen_learner()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(8, pct_start=0.8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.unfreeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(12, slice(1e-6,1e-3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.show_results(rows=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.save('gen-pre-c')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save generated images" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.load('gen-pre-c');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "name_gen = 'image_gen-c'\n", "path_gen = path/name_gen" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# shutil.rmtree(path_gen)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path_gen.mkdir(exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def save_preds(dl):\n", "    \"\"\"Run `learn_gen` on every batch of `dl` and save each prediction into `path_gen`.\n", "\n", "    Predictions are matched to `dl.dataset.items` by running position, and each\n", "    output file reuses the name of the item at that position.\n", "    \"\"\"\n", "    names = dl.dataset.items\n", "    # Lazy stream: one batch is predicted, its outputs saved, then the next batch.\n", "    batches = (learn_gen.pred_batch(batch=b, reconstruct=True) for b in dl)\n", "    outputs = (o for preds in batches for o in preds)\n", "    for i,o in enumerate(outputs):\n", "        o.save(path_gen/names[i].name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_preds(data_gen.fix_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PIL.Image.open(path_gen.ls()[0])" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## Train critic" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen=None\n", "gc.collect()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pretrain the critic on crappy vs not crappy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_crit_data(classes, bs, size):\n", "    \"\"\"Build a DataBunch from the folders named in `classes` under `path`, labelled by folder.\n", "\n", "    A random 10% of the items is held out (seed 42). `size` is forwarded to\n", "    `transform`, `bs` to `databunch`, and the data is normalized with `imagenet_stats`.\n", "    \"\"\"\n", "    # Named `items` so it does not shadow the module-level `src` used by `get_data`.\n", "    items = ImageItemList.from_folder(path, include=classes).random_split_by_pct(0.1, seed=42)\n", "    ll = items.label_from_folder(classes=classes)\n", "    data = (ll.transform(get_transforms(max_zoom=2.), size=size)\n", "              .databunch(bs=bs).normalize(imagenet_stats))\n", "    return data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_critic_learner(data, metrics):\n", "    \"\"\"Return a `Learner` on `data` wrapping a fresh `gan_critic()`, using `loss_critic` and weight decay `wd`.\"\"\"\n", "    return Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Bound to `learn_crit` so the later `learn_crit=None` cleanup actually releases this learner.\n", "learn_crit = create_critic_learner(data_crit, accuracy_thresh_expand)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit.fit_one_cycle(6, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit.save('critic-pre-c')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GAN" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "Now we'll combine those pretrained models in a GAN." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit=None\n", "learn_gen=None\n", "gc.collect()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre-c')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = create_gen_learner().load('gen-pre-c')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss back below the switcher's `critic_thresh` (0.65 in the cell below), then one iteration of the generator.\n", "\n", "The loss of the critic is given by `learn_crit.loss_func`. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).\n", "\n", "The loss of the generator is a weighted sum (weights in `weights_gen`) of `learn_crit.loss_func` on the batch of fake (passed through the critic to become predictions) with a target of 1, and the `learn_gen.loss_func` applied to the output (batch of fake) and the target (corresponding batch of superres images)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n", "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n", " opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)\n", "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n", "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 1e-4" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(10,lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('gan-c')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Trailing ';' keeps the returned learner from being echoed as cell output.\n", "learn.load('gan-c');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Gradual resizing: switch to size 192 with batch size 14 (get_data takes bs, size).\n", "learn.data=get_data(14,192)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(10,lr/2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results(rows=14)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('gan-c')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Gradual resizing: switch to size 256 with batch size 7.\n", "learn.data=get_data(7,256)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(10,lr/4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('gan-c')" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results(rows=7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(20,lr/40)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('gan-c')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('gan-c');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.data=get_data(16,256)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results(rows=14)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.data=get_data(32,192)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results(rows=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }