{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Stable Model Training with monitoring through Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### NOTES: \n", "* This is \"NoGAN\" based training, described in the DeOldify readme.\n", "* This model prioritizes stable and reliable renderings. It does particularly well on portraits and landscapes. It's not as colorful as the artistic model.\n", "* Training is logged and monitored through [Weights & Biases](https://www.wandb.com/)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install W&B Callback\n", "#!pip install wandb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES']='0' " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastai\n", "from fastai import *\n", "from fastai.vision import *\n", "# TODELETE from fastai.callbacks.tensorboard import *\n", "from fastai.vision.gan import *\n", "from deoldify.generators import *\n", "from deoldify.critics import *\n", "from deoldify.dataset import *\n", "from deoldify.loss import *\n", "from deoldify.save import *\n", "from PIL import Image, ImageDraw, ImageFont\n", "from PIL import ImageFile\n", "from torch.utils.data.sampler import RandomSampler, SequentialSampler\n", "import wandb\n", "from wandb.fastai import WandbCallback" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up W&B: checks user can connect to W&B servers\n", "# Note: set up API key the first time\n", "wandb.login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Dataset can be downloaded from https://www.kaggle.com/c/imagenet-object-localization-challenge/data\n", "path = 
Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n", "path_hr = path\n", "path_lr = path/'bandw'\n", "\n", "proj_id = 'StableModel'\n", "\n", "gen_name = proj_id + '_gen'\n", "pre_gen_name = gen_name + '_0'\n", "crit_name = proj_id + '_crit'\n", "\n", "name_gen = proj_id + '_image_gen'\n", "path_gen = path/name_gen\n", "\n", "nf_factor = 2\n", "pct_start = 1e-8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Iterating through the dataset\n", "\n", "The dataset is very large and it would take a long time to iterate through all the samples at each epoch.\n", "\n", "We use custom samplers to limit each epoch to a subset of the data while still working gradually through the entire dataset (epoch after epoch). This lets us run the validation loop more often, where we log metrics as well as prediction samples on validation data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reduce quantity of samples per training epoch\n", "# Adapted from https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10\n", "\n", "@classmethod\n", "def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,\n", " val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,\n", " device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':\n", " \"Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. 
Passes `**dl_kwargs` to `DataLoader()`\"\n", " datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n", " val_bs = ifnone(val_bs, bs)\n", " if sampler is None: sampler = [RandomSampler] + 3*[SequentialSampler]\n", " dls = [DataLoader(d, b, sampler=sa(d), drop_last=sh, num_workers=num_workers, **dl_kwargs) for d,b,sh,sa in\n", " zip(datasets, (bs,val_bs,val_bs,val_bs), (True,False,False,False), sampler) if d is not None]\n", " return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)\n", "\n", "ImageDataBunch.create = create\n", "ImageImageList._bunch = ImageDataBunch\n", "\n", "class FixedLenRandomSampler(RandomSampler):\n", " def __init__(self, data_source, epoch_size):\n", " super().__init__(data_source)\n", " self.epoch_size = epoch_size\n", " self.not_sampled = np.array([True]*len(data_source))\n", " \n", " @property\n", " def reset_state(self): self.not_sampled[:] = True\n", " \n", " def __iter__(self):\n", " ns = sum(self.not_sampled)\n", " idx_last = []\n", " if ns >= len(self):\n", " idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self), replace=False).tolist()\n", " if ns == len(self): self.reset_state\n", " else:\n", " idx_last = np.where(self.not_sampled)[0].tolist()\n", " self.reset_state\n", " idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self)-len(idx_last), replace=False).tolist()\n", " self.not_sampled[idx] = False\n", " idx = [*idx_last, *idx]\n", " return iter(idx)\n", " \n", " def __len__(self):\n", " return self.epoch_size" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_data(bs:int, sz:int, keep_pct=1.0, random_seed=None, valid_pct=0.2, epoch_size=1000):\n", " \n", " # Create samplers\n", " train_sampler = partial(FixedLenRandomSampler, epoch_size=epoch_size)\n", " samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]\n", "\n", " return get_colorize_data(sz=sz, bs=bs, 
crappy_path=path_lr, good_path=path_hr, random_seed=random_seed,\n", " keep_pct=keep_pct, samplers=samplers, valid_pct=valid_pct)\n", "\n", "# Function modified to allow use of custom samplers\n", "def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_seed:int=None,\n", " keep_pct:float=1.0, num_workers:int=8, samplers=None, valid_pct=0.2, xtra_tfms=[])->ImageDataBunch:\n", " src = (ImageImageList.from_folder(crappy_path, convert_mode='RGB')\n", " .use_partial_data(sample_pct=keep_pct, seed=random_seed)\n", " .split_by_rand_pct(valid_pct, seed=random_seed))\n", " data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))\n", " .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)\n", " .databunch(bs=bs, num_workers=num_workers, sampler=samplers, no_check=True)\n", " .normalize(imagenet_stats, do_y=True))\n", " data.c = 3\n", " return data\n", "\n", "def get_crit_data(classes, bs, sz):\n", " src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n", " ll = src.label_from_folder(classes=classes)\n", " data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n", " .databunch(bs=bs).normalize(imagenet_stats))\n", " return data\n", "\n", "def create_training_images(fn,i):\n", " dest = path_lr/fn.relative_to(path_hr)\n", " dest.parent.mkdir(parents=True, exist_ok=True)\n", " img = PIL.Image.open(fn).convert('LA').convert('RGB')\n", " img.save(dest) \n", " \n", "def save_preds(dl):\n", " i=0\n", " names = dl.dataset.items\n", " \n", " for b in dl:\n", " preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n", " for o in preds:\n", " o.save(path_gen/names[i].name)\n", " i += 1\n", " \n", "def save_gen_images():\n", " if path_gen.exists(): shutil.rmtree(path_gen)\n", " path_gen.mkdir(exist_ok=True)\n", " data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n", " save_preds(data_gen.fix_dl)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Create black and white training images" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Only runs if the directory isn't already created." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if not path_lr.exists():\n", " il = ImageList.from_folder(path_hr)\n", " parallel(create_training_images, il.items)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Number of black & white images\n", "data_size = len(list(path_lr.rglob('*.*')))\n", "print('Number of black & white images:', data_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-train generator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### NOTE\n", "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 64px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Init logging of a new run\n", "wandb.init(tags=['Pre-train Gen']) # tags are optional" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=88\n", "sz=64\n", "\n", "# Define target number of training/validation samples as well as number of epochs\n", "epoch_train_size = 100 * bs\n", "epoch_valid_size = 10 * bs\n", "valid_pct = epoch_valid_size / data_size\n", "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n", "\n", "# Log hyper parameters\n", "wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz,\n", " \"Step 1 - epoch size\": epoch_train_size, \"Step 1 - number epochs\": number_epochs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_gen = get_data(bs=bs, sz=sz, random_seed=12345, valid_pct=valid_pct, epoch_size=100*bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TODELETE learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.callback_fns.append(partial(WandbCallback,\n", " input_type='images', # log prediction samples\n", " save_model=False)) # bug in get_monitor_value in fastai v1.0.51 (needed for auto saving best model)\n", " # save_model default can be used if using fastai v1.0.53" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(number_epochs, pct_start=0.8, max_lr=slice(1e-3))" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.save(pre_gen_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.unfreeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.save(pre_gen_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 128px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=20\n", "sz=128\n", "\n", "# Define target number of training/validation samples as well as number of epochs\n", "epoch_train_size = 100 * bs\n", "epoch_valid_size = 10 * bs\n", "valid_pct = epoch_valid_size / data_size\n", "number_epochs = (data_size - epoch_valid_size) // epoch_train_size\n", "\n", "# Log hyper parameters\n", "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz,\n", " \"Step 2 - epoch size\": epoch_train_size, \"Step 2 - number epochs\": number_epochs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=12345, valid_pct=valid_pct, epoch_size=100*bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.unfreeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(1e-7,1e-4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.save(pre_gen_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 192px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "bs=8\n", "sz=192\n", "\n", "# Define target number of training/validation samples as well as number of epochs\n", "epoch_train_size = 100 * bs\n", "epoch_valid_size = 10 * bs\n", "valid_pct = epoch_valid_size / data_size\n", "number_epochs = (data_size - epoch_valid_size) // epoch_train_size // 2 # Training is long - we use half of data\n", "\n", "# Log hyper parameters\n", "wandb.config.update({\"Step 3 - batch size\": bs, \"Step 3 - image size\": sz,\n", " \"Step 3 - epoch size\": epoch_train_size, \"Step 3 - number epochs\": number_epochs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# keep_pct was never defined in this notebook; use the same arguments as the 64px and 128px steps\n", "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=12345, valid_pct=valid_pct, epoch_size=100*bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.unfreeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(number_epochs, pct_start=pct_start, max_lr=slice(5e-8,5e-5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.save(pre_gen_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# End logging of current session run\n", "# Note: this is optional and would be automatically triggered when stopping the kernel\n", "wandb.join()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Repeatable GAN Cycle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### NOTE\n", "Best results so far have come from repeating the cycle below a few times (about 5-8), until diminishing returns are hit (no further improvement in image quality). Each time you repeat the cycle, increment `old_checkpoint_num` by 1 so that new checkpoints don't overwrite the old ones. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "old_checkpoint_num = 0\n", "checkpoint_num = old_checkpoint_num + 1\n", "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n", "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n", "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n", "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save Generated Images" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=8\n", "sz=192" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_gen_images()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pretrain Critic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if old_checkpoint_num == 0:\n", " bs=64\n", " sz=128\n", " learn_gen=None\n", " gc.collect()\n", " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n", " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n", " learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n", " # TODELETE learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n", " learn_critic.fit_one_cycle(6, 1e-3)\n", " learn_critic.save(crit_old_checkpoint_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=16\n", "sz=192" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_critic.fit_one_cycle(4, 1e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_critic.save(crit_new_checkpoint_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GAN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Release learners from the previous steps before building the GAN learner\n", "learn_critic=None # critic learner used in the pretraining cells above\n", "learn_crit=None\n", "learn_gen=None\n", "gc.collect()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr=2e-5\n", "sz=192\n", "bs=5" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = 
get_crit_data([name_gen, 'test'], bs=bs, sz=sz)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n", "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n", " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n", "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n", "# TODELETE learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n", "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Instructions: \n", "Find the checkpoint just before where glitches start to be introduced. This is all very new so you may need to play around with just how far you go here with keep_pct." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n", "learn_gen.freeze_to(-1)\n", "learn.fit(1,lr)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }