Bläddra i källkod

Cleaning up the training notebooks and clarifying with comments

Jason Antic 6 år sedan
förälder
incheckning
3e3207fe92
5 ändrade filer med 236 tillägg och 840 borttagningar
  1. 6 6
      .gitignore
  2. 90 356
      ColorizeTrainingArtistic.ipynb
  3. 80 356
      ColorizeTrainingStable.ipynb
  4. 57 119
      ColorizeTrainingVideo.ipynb
  5. 3 3
      ImageColorizer.ipynb

+ 6 - 6
.gitignore

@@ -2,10 +2,9 @@ data
 fasterai/__pycache__/*.pyc
 *SymbolicLinks.sh
 *.ipynb_checkpoints*
-ColorizeTraining[0-9]*.ipynb
-ColorizeTrainingNew[0-9]*.ipynb
-Colorize[0-9]*.ipynb
-ColorizeVisualization[0-9]*.ipynb
+ColorizeTraining*[0-9]*.ipynb
+*Colorizer[0-9]*.ipynb
+lesson7-superres*.ipynb
 *.pyc
 test.py
 result_images/*.jpg
@@ -24,6 +23,7 @@ test_images/James3.jpg
 test_images/James4.jpg
 test_images/James5.jpg
 test_images/James6.jpg
-fasterai/.ipynb_checkpoints/augs-checkpoint.py
-fasterai/.ipynb_checkpoints/visualize-checkpoint.py
+test_images/image.png
+test_images/image.jpg
+fasterai/.ipynb_checkpoints/*-checkpoint.py
 tmp*

+ 90 - 356
ColorizeTrainingArtistic.ipynb

@@ -4,7 +4,16 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Pretrained GAN"
+    "## Artistic Model Training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### NOTES:  \n",
+    "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
+    "* This model prioritizes colorful renderings.  It has higher variation in renderings at different resolutions compared to the \"stable\" model."
    ]
   },
   {
@@ -14,7 +23,7 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "os.environ['CUDA_VISIBLE_DEVICES']='1' "
+    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
    ]
   },
   {
@@ -32,6 +41,7 @@
     "from fasterai.critics import *\n",
     "from fasterai.dataset import *\n",
     "from fasterai.loss import *\n",
+    "from fasterai.save import *\n",
     "from PIL import Image, ImageDraw, ImageFont\n",
     "from PIL import ImageFile"
    ]
@@ -53,8 +63,10 @@
     "path_hr = path\n",
     "path_lr = path/'bandw'\n",
     "\n",
-    "proj_id = 'Artistic2'\n",
+    "proj_id = 'ArtisticModel'\n",
+    "\n",
     "gen_name = proj_id + '_gen'\n",
+    "pre_gen_name = gen_name + '_0'\n",
     "crit_name = proj_id + '_crit'\n",
     "\n",
     "name_gen = proj_id + '_image_gen'\n",
@@ -62,30 +74,8 @@
     "\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
-    "nf_factor = 1.50\n",
-    "pct_start = 1e-4"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def save_all(suffix=''):\n",
-    "    learn_gen.save(gen_name + str(sz) + suffix)\n",
-    "    learn_crit.save(crit_name + str(sz) + suffix)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def load_all(suffix=''):\n",
-    "    learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
-    "    learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
+    "nf_factor = 1.5\n",
+    "pct_start = 1e-8"
    ]
   },
   {
@@ -96,42 +86,21 @@
    "source": [
     "def get_data(bs:int, sz:int, keep_pct:float):\n",
     "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
-    "                             random_seed=None, keep_pct=keep_pct)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "                             random_seed=None, keep_pct=keep_pct)\n",
+    "\n",
     "def get_crit_data(classes, bs, sz):\n",
     "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
     "    ll = src.label_from_folder(classes=classes)\n",
     "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
     "           .databunch(bs=bs).normalize(imagenet_stats))\n",
-    "    return data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def crappify(fn,i):\n",
+    "    return data\n",
+    "\n",
+    "def create_training_images(fn,i):\n",
     "    dest = path_lr/fn.relative_to(path_hr)\n",
     "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
     "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
-    "    img.save(dest)  "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "    img.save(dest)  \n",
+    "    \n",
     "def save_preds(dl):\n",
     "    i=0\n",
     "    names = dl.dataset.items\n",
@@ -140,15 +109,8 @@
     "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
     "        for o in preds:\n",
     "            o.save(path_gen/names[i].name)\n",
-    "            i += 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "            i += 1\n",
+    "    \n",
     "def save_gen_images(learn_gen):\n",
     "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
     "    path_gen.mkdir(exist_ok=True)\n",
@@ -161,21 +123,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Crappified data"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Prepare the input data by crappifying images."
+    "## Create black and white training images"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Uncomment the first time you run this notebook."
+    "Only runs if the directory isn't already created."
    ]
   },
   {
@@ -184,29 +139,31 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#il = ImageItemList.from_folder(path_hr)\n",
-    "#parallel(crappify, il.items)"
+    "if not path_lr.exists():\n",
+    "    il = ImageItemList.from_folder(path_hr)\n",
+    "    parallel(create_training_images, il.items)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# Pre-training"
+    "## Pre-train generator"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Pre-train generator"
+    "#### NOTE\n",
+    "Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control, and glitch-free results are much easier to obtain than with GAN training."
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now let's pretrain the generator."
+    "### 64px"
    ]
   },
   {
@@ -262,16 +219,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen.load(gen_name, with_opt=False)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
@@ -298,16 +246,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "learn_gen.load(gen_name)"
+    "### 128px"
    ]
   },
   {
@@ -354,7 +300,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 192px"
    ]
   },
   {
@@ -401,84 +354,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Save generated images"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "save_gen_images(gen_name)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Train critic"
+    "## Repeatable GAN Cycle"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Pretrain the critic on crappy vs not crappy."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=64\n",
-    "sz=128"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic = colorize_crit_learner(data=data_crit, nf=256)"
+    "#### NOTE\n",
+    "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment old_checkpoint_num by 1 so that new checkpoints don't overwrite the old ones.  "
    ]
   },
   {
@@ -487,25 +378,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
+    "old_checkpoint_num = 0\n",
+    "checkpoint_num = old_checkpoint_num + 1\n",
+    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
+    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
+    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
+    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.fit_one_cycle(6, 1e-3)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "learn_critic.save(crit_name)"
+    "### Save Generated Images"
    ]
   },
   {
@@ -514,7 +399,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=16\n",
+    "bs=8\n",
     "sz=192"
    ]
   },
@@ -524,7 +409,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+    "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -533,150 +418,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.fit_one_cycle(4, 1e-4)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.save(crit_name)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## GAN"
+    "save_gen_images(gen_name)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now we'll combine those pretrained model in a GAN."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_crit=None\n",
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "lr=1e-5\n",
-    "sz=192\n",
-    "bs=9"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#placeholder- not actually used\n",
-    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name, with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_name, with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
-    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
-    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
-    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "for i in range(1,101):\n",
-    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
-    "    learn_gen.freeze_to(-1)\n",
-    "    learn.fit(1,lr)\n",
-    "    save_all('_01_' + str(i))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "save_all('_01')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn.show_results(rows=bs)"
+    "### Pretrain Critic"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Save Generated Images Again"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=12\n",
-    "sz=192"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew79_gen192_06_40', with_opt=False)"
+    "##### Full pretraining of the critic is only needed when starting from scratch.  Otherwise, just finetune!"
    ]
   },
   {
@@ -685,14 +441,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "save_gen_images(gen_name)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Train Critic Again"
+    "if old_checkpoint_num == 0:\n",
+    "    bs=64\n",
+    "    sz=128\n",
+    "    learn_gen=None\n",
+    "    gc.collect()\n",
+    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
+    "    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
+    "    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
+    "    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
+    "    learn_critic.fit_one_cycle(6, 1e-3)\n",
+    "    learn_critic.save(crit_old_checkpoint_name)"
    ]
   },
   {
@@ -705,16 +464,6 @@
     "sz=192"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -739,7 +488,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '6', with_opt=False)"
+    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -766,14 +515,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.save(crit_name + '7')"
+    "learn_critic.save(crit_new_checkpoint_name)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### GAN Again"
+    "### GAN"
    ]
   },
   {
@@ -813,8 +562,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '7', with_opt=False)\n",
-    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load('ColorizeNew79_crit192_07_100', with_opt=False)"
+    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -823,7 +571,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew79_gen192_07_100', with_opt=False)"
+    "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -836,27 +584,16 @@
     "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
     "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "for i in range(1,101):\n",
-    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
-    "    learn_gen.freeze_to(-1)\n",
-    "    learn.fit(1,lr)\n",
-    "    save_all('_07_' + str(i))"
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
+    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## fin"
+    "#### Instructions:  \n",
+    "Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct."
    ]
   },
   {
@@ -864,14 +601,11 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
+    "learn_gen.freeze_to(-1)\n",
+    "learn.fit(1,lr)"
+   ]
   }
  ],
  "metadata": {
@@ -890,7 +624,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.7.0"
   }
  },
  "nbformat": 4,

+ 80 - 356
ColorizeTrainingStable.ipynb

@@ -4,7 +4,16 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Pretrained GAN"
+    "## Stable Model Training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### NOTES:  \n",
+    "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
+    "* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model."
    ]
   },
   {
@@ -32,6 +41,7 @@
     "from fasterai.critics import *\n",
     "from fasterai.dataset import *\n",
     "from fasterai.loss import *\n",
+    "from fasterai.save import *\n",
     "from PIL import Image, ImageDraw, ImageFont\n",
     "from PIL import ImageFile"
    ]
@@ -53,8 +63,10 @@
     "path_hr = path\n",
     "path_lr = path/'bandw'\n",
     "\n",
-    "proj_id = 'Stable'\n",
+    "proj_id = 'StableModel'\n",
+    "\n",
     "gen_name = proj_id + '_gen'\n",
+    "pre_gen_name = gen_name + '_0'\n",
     "crit_name = proj_id + '_crit'\n",
     "\n",
     "name_gen = proj_id + '_image_gen'\n",
@@ -63,29 +75,7 @@
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
     "nf_factor = 2\n",
-    "pct_start = 1e-4"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def save_all(suffix=''):\n",
-    "    learn_gen.save(gen_name + str(sz) + suffix)\n",
-    "    learn_crit.save(crit_name + str(sz) + suffix)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def load_all(suffix=''):\n",
-    "    learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
-    "    learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
+    "pct_start = 1e-8"
    ]
   },
   {
@@ -96,42 +86,21 @@
    "source": [
     "def get_data(bs:int, sz:int, keep_pct:float):\n",
     "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
-    "                             random_seed=None, keep_pct=keep_pct)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "                             random_seed=None, keep_pct=keep_pct)\n",
+    "\n",
     "def get_crit_data(classes, bs, sz):\n",
     "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
     "    ll = src.label_from_folder(classes=classes)\n",
     "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
     "           .databunch(bs=bs).normalize(imagenet_stats))\n",
-    "    return data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def crappify(fn,i):\n",
+    "    return data\n",
+    "\n",
+    "def create_training_images(fn,i):\n",
     "    dest = path_lr/fn.relative_to(path_hr)\n",
     "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
     "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
-    "    img.save(dest)  "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "    img.save(dest)  \n",
+    "    \n",
     "def save_preds(dl):\n",
     "    i=0\n",
     "    names = dl.dataset.items\n",
@@ -140,15 +109,8 @@
     "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
     "        for o in preds:\n",
     "            o.save(path_gen/names[i].name)\n",
-    "            i += 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "            i += 1\n",
+    "    \n",
     "def save_gen_images(learn_gen):\n",
     "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
     "    path_gen.mkdir(exist_ok=True)\n",
@@ -161,21 +123,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Crappified data"
+    "## Create black and white training images"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Prepare the input data by crappifying images."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Uncomment the first time you run this notebook."
+    "Only runs if the directory isn't already created."
    ]
   },
   {
@@ -184,29 +139,31 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#il = ImageItemList.from_folder(path_hr)\n",
-    "#parallel(crappify, il.items)"
+    "if not path_lr.exists():\n",
+    "    il = ImageItemList.from_folder(path_hr)\n",
+    "    parallel(create_training_images, il.items)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# Pre-training"
+    "## Pre-train generator"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Pre-train generator"
+    "#### NOTE\n",
+    "Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control, and glitch-free results are much easier to obtain than with GAN training."
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now let's pretrain the generator."
+    "### 64px"
    ]
   },
   {
@@ -262,7 +219,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
@@ -271,7 +228,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.load(gen_name, with_opt=False)"
+    "learn_gen.unfreeze()"
    ]
   },
   {
@@ -280,7 +237,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.unfreeze()"
+    "learn_gen.fit_one_cycle(1, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))"
    ]
   },
   {
@@ -289,16 +246,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.fit_one_cycle(1, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "### 128px"
    ]
   },
   {
@@ -330,15 +285,6 @@
     "learn_gen.unfreeze()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen.load(gen_name, with_opt=False)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -354,16 +300,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "learn_gen.load(gen_name, with_opt=False)"
+    "### 192px"
    ]
   },
   {
@@ -410,75 +354,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Save generated images"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "save_gen_images(gen_name)"
+    "## Repeatable GAN Cycle"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Train critic"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Pretrain the critic on crappy vs not crappy."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=64\n",
-    "sz=128"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
+    "#### NOTE\n",
+    "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality).  Each time you repeat the cycle, you want to increment old_checkpoint_num by 1 so that new checkpoints don't overwrite the old ones.  "
    ]
   },
   {
@@ -487,117 +378,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic = colorize_crit_learner(data=data_crit, nf=256)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.fit_one_cycle(6, 1e-3)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.save(crit_name + '1')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=16\n",
-    "sz=192"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.fit_one_cycle(4, 1e-4)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_critic.save(crit_name + '1')"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## GAN"
+    "old_checkpoint_num = 0\n",
+    "checkpoint_num = old_checkpoint_num + 1\n",
+    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
+    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
+    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
+    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now we'll combine those pretrained model in a GAN."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_crit=None\n",
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "lr=2e-5\n",
-    "sz=192\n",
-    "bs=5"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#placeholder- not actually used\n",
-    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+    "### Save Generated Images"
    ]
   },
   {
@@ -606,16 +399,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '1', with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_name, with_opt=False)"
+    "bs=8\n",
+    "sz=192"
    ]
   },
   {
@@ -624,11 +409,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
-    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
-    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
-    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
+    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -637,38 +418,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "for i in range(1,101):\n",
-    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
-    "    learn_gen.freeze_to(-1)\n",
-    "    learn.fit(1,lr)\n",
-    "    save_all('_1_' + str(i))"
+    "save_gen_images(gen_name)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Repeat Pretrain-GAN Cycle"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "old_checkpoint_num = 5\n",
-    "checkpoint_num = old_checkpoint_num + 1\n",
-    "gen_old_checkpoint_name = 'ColorizeNew73_gen192_5_7'\n",
-    "crit_old_checkpoint_name = crit_name + str(old_checkpoint_num)\n",
-    "crit_new_checkpoint_name= crit_name + str(checkpoint_num)"
+    "### Pretrain Critic"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Save Generated Images Again"
+    "##### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!"
    ]
   },
   {
@@ -677,33 +441,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=8\n",
-    "sz=192"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "save_gen_images(gen_name)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Train Critic Again"
+    "if old_checkpoint_num == 0:\n",
+    "    bs=64\n",
+    "    sz=128\n",
+    "    learn_gen=None\n",
+    "    gc.collect()\n",
+    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
+    "    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
+    "    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
+    "    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
+    "    learn_critic.fit_one_cycle(6, 1e-3)\n",
+    "    learn_critic.save(crit_old_checkpoint_name)"
    ]
   },
   {
@@ -716,16 +464,6 @@
     "sz=192"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen=None\n",
-    "gc.collect()"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -784,7 +522,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### GAN Again"
+    "### GAN"
    ]
   },
   {
@@ -846,27 +584,16 @@
     "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
     "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "for i in range(1,101):\n",
-    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
-    "    learn_gen.freeze_to(-1)\n",
-    "    learn.fit(1,lr)\n",
-    "    save_all('_' + str(checkpoint_num) '_' + str(i))"
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
+    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## fin"
+    "#### Instructions:  \n",
+    "Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct."
    ]
   },
   {
@@ -874,14 +601,11 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
+    "learn_gen.freeze_to(-1)\n",
+    "learn.fit(1,lr)"
+   ]
   }
  ],
  "metadata": {

+ 57 - 119
ColorizeTrainingVideo.ipynb

@@ -4,7 +4,16 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Pretrained GAN"
+    "## Video Model Training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### NOTES:  \n",
+    "* It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available at the specified path.\n",
+    "* This is \"NoGAN\" based training, described in the DeOldify readme."
    ]
   },
   {
@@ -32,6 +41,7 @@
     "from fasterai.critics import *\n",
     "from fasterai.dataset import *\n",
     "from fasterai.loss import *\n",
+    "from fasterai.save import *\n",
     "from PIL import Image, ImageDraw, ImageFont\n",
     "from PIL import ImageFile"
    ]
@@ -53,8 +63,9 @@
     "path_hr = path\n",
     "path_lr = path/'bandw'\n",
     "\n",
-    "proj_id = 'WideNoise4'\n",
+    "proj_id = 'VideoModel'\n",
     "gen_name = proj_id + '_gen'\n",
+    "pre_gen_name = gen_name + '_0'\n",
     "crit_name = proj_id + '_crit'\n",
     "\n",
     "name_gen = proj_id + '_image_gen'\n",
@@ -63,29 +74,8 @@
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
     "nf_factor = 2\n",
-    "xtra_tfms=[noisify(p=0.8)]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def save_all(suffix=''):\n",
-    "    learn_gen.save(gen_name + str(sz) + suffix)\n",
-    "    learn_crit.save(crit_name + str(sz) + suffix)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def load_all(suffix=''):\n",
-    "    learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
-    "    learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
+    "xtra_tfms=[noisify(p=0.8)]\n",
+    "pct_start = 1e-8"
    ]
   },
   {
@@ -96,42 +86,15 @@
    "source": [
     "def get_data(bs:int, sz:int, keep_pct:float):\n",
     "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
-    "                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)\n",
+    "\n",
     "def get_crit_data(classes, bs, sz):\n",
     "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
     "    ll = src.label_from_folder(classes=classes)\n",
     "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
     "           .databunch(bs=bs).normalize(imagenet_stats))\n",
-    "    return data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def crappify(fn,i):\n",
-    "    dest = path_lr/fn.relative_to(path_hr)\n",
-    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
-    "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
-    "    img.save(dest)  "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "    return data\n",
+    "    \n",
     "def save_preds(dl):\n",
     "    i=0\n",
     "    names = dl.dataset.items\n",
@@ -140,15 +103,8 @@
     "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
     "        for o in preds:\n",
     "            o.save(path_gen/names[i].name)\n",
-    "            i += 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "            i += 1\n",
+    "            \n",
     "def save_gen_images(learn_gen):\n",
     "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
     "    path_gen.mkdir(exist_ok=True)\n",
@@ -161,7 +117,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Finetune Generator With Noise."
+    "## Finetune Generator With Noise Augmented Images."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### This helps the generator better deal with noisy/grainy video (which is pretty normal)."
    ]
   },
   {
@@ -208,7 +171,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen = learn_gen.load(gen_name, with_opt=False)"
+    "learn_gen = learn_gen.load(pre_gen_name, with_opt=False)"
    ]
   },
   {
@@ -226,7 +189,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.fit_one_cycle(1, pct_start=0.01, max_lr=slice(5e-8,5e-5))"
+    "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
    ]
   },
   {
@@ -235,14 +198,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.save(gen_name)"
+    "learn_gen.save(pre_gen_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Repeatable GAN Cycle"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Repeat Pretrain-GAN Cycle"
+    "#### NOTE\n",
+    "Best results so far have been based on only doing a single run of the cells below (otherwise glitches are introduced that are visible in video).  "
    ]
   },
   {
@@ -251,18 +222,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "old_checkpoint_num = 1\n",
+    "old_checkpoint_num = 0\n",
     "checkpoint_num = old_checkpoint_num + 1\n",
-    "gen_old_checkpoint_name = 'WideNoise4_gen'\n",
-    "crit_old_checkpoint_name = crit_name + str(old_checkpoint_num)\n",
-    "crit_new_checkpoint_name= crit_name + str(checkpoint_num)"
+    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
+    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
+    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
+    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Save Generated Images Again"
+    "### Save Generated Images"
    ]
   },
   {
@@ -297,7 +269,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Train Critic Again"
+    "### Pretrain Critic"
    ]
   },
   {
@@ -378,7 +350,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### GAN Again"
+    "### GAN"
    ]
   },
   {
@@ -440,27 +412,16 @@
     "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
     "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
-    "learn_gen.freeze_to(-1)\n",
-    "learn.fit(1,lr)"
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))\n",
+    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "save_all('_' + str(checkpoint_num) + '_' + 'derp')"
+    "#### Instructions:  \n",
+    "Find the checkpoint just before where glitches start to be introduced.  So far this has been found at the point of iterating through 1.4% of the data when using a learning rate of 1e-5, and at 2.2% of the data for 5e-6."
    ]
   },
   {
@@ -469,33 +430,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#for i in range(1,31):\n",
-    "    #learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
-    "    #learn_gen.freeze_to(-1)\n",
-    "    #learn.fit(1,lr)\n",
-    "    #save_all('_' + str(checkpoint_num) + '_' + str(i))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## fin"
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
+    "learn_gen.freeze_to(-1)\n",
+    "learn.fit(1,lr)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {

+ 3 - 3
ImageColorizer.ipynb

@@ -84,7 +84,7 @@
     "#NOTE:  Max is 45 with 11GB video cards. 35 is a good default\n",
     "render_factor=35\n",
     "#NOTE:  Make source_url None to just read from file at ./video/source/[file_name] directly without modification\n",
-    "source_url='https://i.imgur.com/NajrX6Z.jpg'\n",
+    "source_url='https://i.redd.it/4k2pz0e9yts21.jpg'\n",
     "source_path = 'test_images/image.png'\n",
     "result_path = None\n",
     "\n",
@@ -109,8 +109,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "for i in range(10,46,2):\n",
-    "    colorizer.plot_transformed_image(source_path, render_factor=i, display_render_factor=True, figsize=(8,8))"
+    "for i in range(10,46):\n",
+    "    colorizer.plot_transformed_image(source_path, render_factor=i, display_render_factor=True, figsize=(10,10))"
    ]
   },
   {