
Committing changes I should have committed ages ago. Doh!

Jason Antic 6 years ago
commit
7b5e205ef4
7 changed files with 1122 additions and 124 deletions
  1. ColorizeTrainingWide.ipynb (+914 -0)
  2. ImageColorizer.ipynb (+7 -11)
  3. README.md (+6 -6)
  4. VideoColorizer.ipynb (+76 -78)
  5. fasterai/generators.py (+13 -13)
  6. fasterai/unet.py (+7 -7)
  7. fasterai/visualize.py (+99 -9)

+ 914 - 0
ColorizeTrainingWide.ipynb

@@ -0,0 +1,914 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pretrained GAN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import fastai\n",
+    "from fastai import *\n",
+    "from fastai.vision import *\n",
+    "from fastai.callbacks.tensorboard import *\n",
+    "from fastai.vision.gan import *\n",
+    "from fasterai.generators import *\n",
+    "from fasterai.critics import *\n",
+    "from fasterai.dataset import *\n",
+    "from fasterai.loss import *\n",
+    "from PIL import Image, ImageDraw, ImageFont\n",
+    "from PIL import ImageFile"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
+    "path_hr = path\n",
+    "path_lr = path/'bandw'\n",
+    "\n",
+    "proj_id = 'ColorizeNew73'\n",
+    "gen_name = proj_id + '_gen'\n",
+    "crit_name = proj_id + '_crit'\n",
+    "\n",
+    "name_gen = proj_id + '_image_gen'\n",
+    "path_gen = path/name_gen\n",
+    "\n",
+    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
+    "\n",
+    "nf_factor = 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_all(suffix=''):\n",
+    "    learn_gen.save(gen_name + str(sz) + suffix)\n",
+    "    learn_crit.save(crit_name + str(sz) + suffix)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def load_all(suffix=''):\n",
+    "    learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
+    "    learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_data(bs:int, sz:int, keep_pct:float):\n",
+    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
+    "                             random_seed=None, keep_pct=keep_pct)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_crit_data(classes, bs, sz):\n",
+    "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
+    "    ll = src.label_from_folder(classes=classes)\n",
+    "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
+    "           .databunch(bs=bs).normalize(imagenet_stats))\n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def crappify(fn,i):\n",
+    "    dest = path_lr/fn.relative_to(path_hr)\n",
+    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
+    "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
+    "    img.save(dest)  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_preds(dl):\n",
+    "    i=0\n",
+    "    names = dl.dataset.items\n",
+    "    \n",
+    "    for b in dl:\n",
+    "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
+    "        for o in preds:\n",
+    "            o.save(path_gen/names[i].name)\n",
+    "            i += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_gen_images(learn_gen):\n",
+    "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
+    "    path_gen.mkdir(exist_ok=True)\n",
+    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
+    "    save_preds(data_gen.fix_dl)\n",
+    "    PIL.Image.open(path_gen.ls()[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Crappified data"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Prepare the input data by crappifying images."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Uncomment the first time you run this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#il = ImageItemList.from_folder(path_hr)\n",
+    "#parallel(crappify, il.items)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Pre-training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Pre-train generator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now let's pretrain the generator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=88\n",
+    "sz=64\n",
+    "keep_pct=1.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = gen_learner_deep(arch=models.resnet101, data=data_gen, gen_loss=FeatureLoss2(), nf_factor=nf_factor)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(2, pct_start=0.8, max_lr=slice(1e-3))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.load(gen_name, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(2, pct_start=0.01,  max_lr=slice(3e-7, 3e-4))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=20\n",
+    "sz=128\n",
+    "keep_pct=1.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.load(gen_name, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(2, pct_start=0.01, max_lr=slice(1e-7,1e-4))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.load(gen_name, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=8\n",
+    "sz=192\n",
+    "keep_pct=0.50"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(1, pct_start=0.01, max_lr=slice(5e-8,5e-5))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Save generated images"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save_gen_images(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Train critic"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretrain the critic on crappy vs not crappy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=64\n",
+    "sz=128"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic = colorize_crit_learner(data=data_crit, nf=256)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.fit_one_cycle(6, 1e-3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.save(crit_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=16\n",
+    "sz=192"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.fit_one_cycle(4, 1e-4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.save(crit_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## GAN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we'll combine those pretrained model in a GAN."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit=None\n",
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr=2e-5\n",
+    "sz=192\n",
+    "bs=5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#placeholder- not actually used\n",
+    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = gen_learner_wide(arch=models.resnet101, data=data_gen, gen_loss=FeatureLoss2(), nf_factor=nf_factor).load(gen_name, with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
+    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
+    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
+    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i in range(1,101):\n",
+    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
+    "    learn_gen.freeze_to(-1)\n",
+    "    learn.fit(1,lr)\n",
+    "    save_all('_03_' + str(i))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save_all('_01')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Save Generated Images Again"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=8\n",
+    "sz=192"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = gen_learner_wide(arch=models.resnet101, data=data_gen, gen_loss=FeatureLoss2(), nf_factor=nf_factor).load('ColorizeNew73_gen192_05_7', with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save_gen_images(gen_name)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Train Critic Again"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=16\n",
+    "sz=192"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '5', with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.fit_one_cycle(4, 1e-4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.save(crit_name + '6')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### GAN Again"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit=None\n",
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr=1e-6\n",
+    "sz=192\n",
+    "bs=5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '6', with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = gen_learner_wide(arch=models.resnet101, data=data_gen, gen_loss=FeatureLoss2(), nf_factor=nf_factor).load('ColorizeNew73_gen192_05_7', with_opt=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
+    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
+    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
+    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i in range(1,101):\n",
+    "    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
+    "    learn_gen.freeze_to(-1)\n",
+    "    learn.fit(1,lr)\n",
+    "    save_all('_06_' + str(i))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## fin"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
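
Taken together, the cells above implement a three-phase "GAN Supertransfer" cycle. Below is a compact outline of that cycle, lightly condensed, reusing the notebook's own definitions (get_data, save_gen_images, get_crit_data, colorize_crit_learner, save_all, learn, lr), so it only runs in the context of this notebook:

    # Phase 1: pretrain the generator alone with the feature loss,
    # stepping the image size up 64 -> 128 -> 192.
    for bs, sz, keep_pct, lrs in [(88, 64, 1.0, slice(3e-7, 3e-4)),
                                  (20, 128, 1.0, slice(1e-7, 1e-4)),
                                  (8, 192, 0.5, slice(5e-8, 5e-5))]:
        learn_gen.data = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
        learn_gen.unfreeze()
        learn_gen.fit_one_cycle(1, pct_start=0.01, max_lr=lrs)
        learn_gen.save(gen_name)

    # Phase 2: write the pretrained generator's outputs to disk, then
    # train the critic as a plain binary classifier: generated vs. real.
    save_gen_images()
    data_crit = get_crit_data([name_gen, 'test'], bs=64, sz=128)
    learn_critic = colorize_crit_learner(data=data_crit, nf=256)
    learn_critic.fit_one_cycle(6, 1e-3)

    # Phase 3: short GAN rounds at the target size on a tiny slice of
    # the data (keep_pct=0.001), checkpointing every round so the
    # best-looking checkpoint can be picked out afterwards.
    for i in range(1, 101):
        learn.data = get_data(sz=192, bs=5, keep_pct=0.001)
        learn_gen.freeze_to(-1)
        learn.fit(1, lr)
        save_all('_03_' + str(i))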

+ 7 - 11
ImageColorizer.ipynb

@@ -42,13 +42,7 @@
     "#It literally just is a number multiplied by 16 to get the square render resolution.  \n",
     "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
     "#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  \n",
-    "render_factor=17\n",
-    "root_folder =  Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
-    "weights_name = 'ColorizeNew68_gen192_01_5'\n",
-    "nf_factor = 1.25\n",
-    "\n",
-    "#weights_name = 'ColorizeNew70_gen192_01_5'\n",
-    "#nf_factor = 1.25"
+    "render_factor=48"
    ]
   },
   {
@@ -57,8 +51,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis = get_colorize_visualizer(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor, render_factor=render_factor)\n",
-    "#vis = get_colorize_visualizer(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor, render_factor=render_factor)"
+    "vis = get_video_colorizer(root_folder=Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw'), weights_name='ColorizeNew76_gen192_01_28', render_factor=render_factor, nf_factor=1.5).vis\n",
+    "#vis = get_video_colorizer(root_folder=Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw'), weights_name='ColorizeNew72_gen192_05_18', render_factor=render_factor).vis\n",
+    "#vis = get_image_colorizer(render_factor=render_factor)\n",
+    "#vis = get_image_colorizer(arch=models.resnet101, root_folder=Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw'), weights_name='ColorizeNew73_gen192_06_80', render_factor=render_factor)"
    ]
   },
   {
@@ -1462,7 +1458,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/LivingRoom1920Sweeden.jpg\")"
+    "vis.plot_transformed_image(\"test_images/LivingRoom1920Sweden.jpg\")"
    ]
   },
   {
@@ -2407,7 +2403,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/WaterfordIreland1909.jpg\", figsize=(70,70))"
+    "vis.plot_transformed_image(\"test_images/WaterfordIreland1909.jpg\")"
    ]
   },
   {
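
The render_factor comment at the top of this notebook is worth restating as code. The helper below is hypothetical (nothing named render_resolution exists in the repo); it just makes the arithmetic from the comment explicit:

    def render_resolution(render_factor: int) -> int:
        # Square resolution (px) at which color is rendered internally;
        # the final output keeps the input image's own resolution.
        return render_factor * 16

    assert render_resolution(21) == 336  # the example in the comment above
    assert render_resolution(48) == 768  # the value this commit switches to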

+ 6 - 6
README.md

@@ -94,13 +94,13 @@ Seneca Native in 1908
 
 This is a deep learning based model.  More specifically, what I've done is combined the following approaches:
 * **Self-Attention Generative Adversarial Network** (https://arxiv.org/abs/1805.08318).  Except the generator is a **pretrained U-Net**, and I've just modified it to have the spectral normalization and self-attention.  It's a pretty straightforward translation.  
-* **Two Time-Scale Update Rule** (https://arxiv.org/abs/1706.08500).  This is also very straightforward – it's just one to one generator/critic iterations and higher critic learning rate. This is modified to incorporate a "threshold" critic loss that makes sure that the critic is "caught up" before moving on to generator training.  This is particularly useful for the GAN supertransfer learning method described next.
-* **GAN Supertransfer Learning**  There's no paper here!  And I just totally made up that catchy term.  But it's the best way I can describe it.  Basically what you do is you first train the generator in a conventional way by itself with just the feature loss.  Then you generate images from that, and train the critic on distinguishing between those outputs and real images as a basic binary classifier.  Finally, you train the generator and critic together in a GAN setting (starting right at the target size of 192px in this case).  This training is super quick- only 1-10% of Imagenet dataset is iterated through, once!  Yet during this very short amount of GAN training the generator not only gets the full realistic colorization capabilities that we used to get through days of progressively resized GAN training, but it also doesn't accrue any of the artifacts and other ugly baggage of GANs. As far as I know this is a new technique.  And it's incredibly effective.  It seems paper-worthy but I'll leave the paper to whoever's so inclined (not I!).  This builds upon a technique developed in collaboration with Jeremy Howard and Sylvain Gugger (so fun!) for Fast.AI's Lesson 7 in version 3 of Practical Deep Learning for Coders part I.  The particular lesson notebook can be found here:  https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb   
+* **Two Time-Scale Update Rule** (https://arxiv.org/abs/1706.08500).  This is also very straightforward – it's just one to one generator/critic iterations and higher critic learning rate. This is modified to incorporate a "threshold" critic loss that makes sure that the critic is "caught up" before moving on to generator training.  This is particularly useful for the GAN Supertransfer Learning method described next.
+* **GAN Supertransfer Learning**.  There's no paper here! And I just totally made up that catchy term. But it's the best way I can describe it. Basically what you do is you first train the generator in a conventional way by itself with just the feature loss. Then you generate images from that, and train the critic on distinguishing between those outputs and real images as a basic binary classifier. Finally, you train the generator and critic together in a GAN setting (starting right at the target size of 192px in this case). This training is super quick: only 1-10% of the Imagenet dataset is iterated through, once! Yet during this very short amount of GAN training the generator not only gets the full realistic colorization capabilities that we used to get through days of progressively resized GAN training, but it also doesn't accrue any of the artifacts and other ugly baggage of GANs. As far as I know this is a new technique. And it's incredibly effective. It seems paper-worthy but I'll leave the paper to whoever is so inclined (not I!). This builds upon a technique developed in collaboration with Jeremy Howard and Sylvain Gugger (so fun!) for Fast.AI's Lesson 7 in version 3 of Practical Deep Learning for Coders part I. The particular lesson notebook can be found here: https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson7-superres-gan.ipynb  
 * **Generator Loss** during GAN Supertransfer Learning has two parts:  One is a basic Perceptual Loss (or Feature Loss) based on VGG16 – this just biases the generator model to replicate the input image.  The second is the loss score from the critic.  For the curious – Perceptual Loss isn't sufficient by itself to produce good results.  It tends to just encourage a bunch of brown/green/blue – you know, cheating to the test, basically, which neural networks are really good at doing!  Key thing to realize here is that GANs essentially are learning the loss function for you – which is really one big step closer to the ideal that we're shooting for in machine learning.  And of course you generally get much better results when you get the machine to learn something you were previously hand coding.  That's certainly the case here.
 
 Of note:  There's no longer any "Progressive Growing of GANs" type training going on here.  It's just not needed, given the superior results obtained by the GAN Supertransfer Learning technique described above.
 
-The beauty of this model is that it should be generally useful for all sorts of image modification, and it should do it quite well.  What you're seeing above are the results of the colorization model, but that's just one component in a pipeline that I'm looking to develop here with the exact same model. 
+The beauty of this model is that it should be generally useful for all sorts of image modification, and it should do it quite well.  What you're seeing above are the results of the colorization model, but that's just one component in a pipeline that I'm looking to develop here with the exact same approach.
 
 
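The "threshold" critic rule in the Two Time-Scale bullet above is easy to state as code. The snippet below is a from-scratch illustration of the rule only, not fastai's implementation; the actual mechanism is fastai's AdaptiveGANSwitcher callback, which ColorizeTrainingWide.ipynb instantiates with critic_thresh=0.65:

    def choose_next_step(critic_loss: float, critic_thresh: float = 0.65) -> str:
        # Keep training the critic until it is "caught up" (its loss has
        # dropped to the threshold), then hand the next step to the generator.
        return 'critic' if critic_loss > critic_thresh else 'generator'

    assert choose_next_step(0.9) == 'critic'
    assert choose_next_step(0.5) == 'generator'
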
 ### This Project, Going Forward
@@ -115,7 +115,7 @@ The easiest way to get started is to go straight to the Colab notebooks:
 
 Image [<img src="https://colab.research.google.com/assets/colab-badge.svg" align="center">](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb) | Video [<img src="https://colab.research.google.com/assets/colab-badge.svg" align="center">](https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb) 
 
-Special thanks to Matt Robinson and Maria Benevente for their image Colab notebook contributions, and Robert Bell for the video Colab notebook work!
+Special thanks to Matt Robinson and María Benavente for their image Colab notebook contributions, and Robert Bell for the video Colab notebook work!
 
 -----------------------
 
@@ -123,8 +123,8 @@ Special thanks to Matt Robinson and Maria Benevente for their image Colab notebo
 
 #### Hardware and Operating System Requirements
 
-* **(Training Only) BEEFY Graphics card**.  I'd really like to have more memory than the 11 GB in my GeForce 1080TI (11GB).  You'll have a tough time with less.  The Unet and Critic are ridiculously large.  
-* **(Colorization Alone) A decent graphics card**. Approximately 3GB+ memory video cards should be sufficient.
+* **(Training Only) BEEFY Graphics card**.  I'd really like to have more memory than the 11 GB in my GeForce 1080TI (11GB).  You'll have a tough time with less.  The Generators and Critic are ridiculously large.  
+* **(Colorization Alone) A decent graphics card**. Approximately 4GB+ memory video cards should be sufficient.
 * **Linux (or maybe Windows 10)**  I'm using Ubuntu 16.04, but nothing about this precludes Windows 10 support as far as I know.  I just haven't tested it and am not going to make it a priority for now.  
 
 #### Easy Install

File diff suppressed because it is too large
+ 76 - 78
VideoColorizer.ipynb


+ 13 - 13
fasterai/generators.py

@@ -1,31 +1,31 @@
 from fastai.vision import *
 from fastai.vision.learner import cnn_config
-from .unet import CustomDynamicUnet, CustomDynamicUnet2
+from .unet import DynamicUnetWide, DynamicUnetDeep
 from .loss import FeatureLoss
 from .dataset import *
 
 #Weights are implicitly read from ./models/ folder 
-def colorize_gen_inference(root_folder:Path, weights_name:str, nf_factor:float)->Learner:
+def gen_inference_deep(root_folder:Path, weights_name:str, arch=models.resnet34, nf_factor:float=1.25)->Learner:
       data = get_dummy_databunch()
-      learn = colorize_gen_learner(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor)
+      learn = gen_learner_deep(data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor)
       learn.path = root_folder
       learn.load(weights_name)
       learn.model.eval()
       return learn
 
-def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:float=1.0)->Learner:
-    return custom_unet_learner(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
+def gen_learner_deep(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:float=1.25)->Learner:
+    return unet_learner_deep(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                         self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
 
 #The code below is meant to be merged into fastaiv1 ideally
-def custom_unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+def unet_learner_deep(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                  norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
                  blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
                  bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->Learner:
     "Build Unet learner from `data` and `arch`."
     meta = cnn_config(arch)
     body = create_body(arch, pretrained)
-    model = to_device(CustomDynamicUnet(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+    model = to_device(DynamicUnetDeep(body, n_classes=data.c, blur=blur, blur_final=blur_final,
           self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
           bottle=bottle, nf_factor=nf_factor), data.device)
     learn = Learner(data, model, **kwargs)
@@ -37,27 +37,27 @@ def custom_unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blu
 #-----------------------------
 
 #Weights are implicitly read from ./models/ folder 
-def colorize_gen_inference2(root_folder:Path, weights_name:str, nf_factor:int, arch=models.resnet34)->Learner:
+def gen_inference_wide(root_folder:Path, weights_name:str, nf_factor:int=2, arch=models.resnet34)->Learner:
       data = get_dummy_databunch()
-      learn = colorize_gen_learner2(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch)
+      learn = gen_learner_wide(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch)
       learn.path = root_folder
       learn.load(weights_name)
       learn.model.eval()
       return learn
 
-def colorize_gen_learner2(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:int=1)->Learner:
-    return custom_unet_learner2(data, arch=arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
+def gen_learner_wide(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:int=2)->Learner:
+    return unet_learner_wide(data, arch=arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                         self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
 
 #The code below is meant to be merged into fastaiv1 ideally
-def custom_unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+def unet_learner_wide(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                  norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
                  blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
                  bottle:bool=False, nf_factor:int=1, **kwargs:Any)->Learner:
     "Build Unet learner from `data` and `arch`."
     meta = cnn_config(arch)
     body = create_body(arch, pretrained)
-    model = to_device(CustomDynamicUnet2(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+    model = to_device(DynamicUnetWide(body, n_classes=data.c, blur=blur, blur_final=blur_final,
           self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
           bottle=bottle, nf_factor=nf_factor), data.device)
     learn = Learner(data, model, **kwargs)

+ 7 - 7
fasterai/unet.py

@@ -7,7 +7,7 @@ from fastai.vision import *
 
 #The code below is meant to be merged into fastaiv1 ideally
 
-__all__ = ['CustomDynamicUnet', 'CustomUnetBlock', 'CustomPixelShuffle_ICNR']
+__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
 
 def _get_sfs_idxs(sizes:Sizes) -> List[int]:
     "Get the indexes of the layers where the size of the activation changes."
@@ -35,7 +35,7 @@ class CustomPixelShuffle_ICNR(nn.Module):
         x = self.shuf(self.relu(self.conv(x)))
         return self.blur(self.pad(x)) if self.blur else x
 
-class CustomUnetBlock(nn.Module):
+class UnetBlockDeep(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
     def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                  self_attention:bool=False, nf_factor:float=1.0,  **kwargs):
@@ -59,7 +59,7 @@ class CustomUnetBlock(nn.Module):
         return self.conv2(self.conv1(cat_x))
 
 
-class CustomDynamicUnet(SequentialEx):
+class DynamicUnetDeep(SequentialEx):
     "Create a U-Net from a given architecture."
     def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                  y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
@@ -82,7 +82,7 @@ class CustomDynamicUnet(SequentialEx):
             up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
             do_blur = blur and (not_final or blur_final)
             sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = CustomUnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+            unet_block = UnetBlockDeep(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
                                    norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
             layers.append(unet_block)
             x = unet_block(x)
@@ -104,7 +104,7 @@ class CustomDynamicUnet(SequentialEx):
 
 
 #------------------------------------------------------
-class CustomUnetBlock2(nn.Module):
+class UnetBlockWide(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
     def __init__(self, up_in_c:int, x_in_c:int, n_out:int,  hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                  self_attention:bool=False,  **kwargs):
@@ -127,7 +127,7 @@ class CustomUnetBlock2(nn.Module):
         return self.conv(cat_x)
 
 
-class CustomDynamicUnet2(SequentialEx):
+class DynamicUnetWide(SequentialEx):
     "Create a U-Net from a given architecture."
     def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                  y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
@@ -155,7 +155,7 @@ class CustomDynamicUnet2(SequentialEx):
 
             n_out = nf if not_final else nf//2
 
-            unet_block = CustomUnetBlock2(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+            unet_block = UnetBlockWide(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
                                    norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
             layers.append(unet_block)
             x = unet_block(x)

+ 99 - 9
fasterai/visualize.py

@@ -4,11 +4,13 @@ from matplotlib.axes import Axes
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from .filters import IFilter, MasterFilter, ColorizerFilter
-from .generators import colorize_gen_inference, colorize_gen_inference2
+from .generators import gen_inference_deep, gen_inference_wide
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
 from PIL import Image 
+import ffmpeg
+import youtube_dl
 
 
 class ModelImageVisualizer():
@@ -51,17 +53,106 @@ class ModelImageVisualizer():
         rows = rows if rows * columns == num_images else rows + 1
         return rows, columns
 
+class VideoColorizer():
+    def __init__(self, vis:ModelImageVisualizer):
+        self.vis=vis
+        workfolder = Path('./video')
+        self.source_folder = workfolder/"source"
+        self.bwframes_root = workfolder/"bwframes"
+        self.audio_root = workfolder/"audio"
+        self.colorframes_root = workfolder/"colorframes"
+        self.result_folder = workfolder/"result"
+
+    def _purge_images(self, dir):
+        for f in os.listdir(dir):
+            if re.search(r'.*?\.jpg', f):
+                os.remove(os.path.join(dir, f))
+
+    def _get_fps(self, source_path: Path)->float:
+        probe = ffmpeg.probe(str(source_path))
+        stream_data = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+        avg_frame_rate = stream_data['avg_frame_rate']
+        fps_num = avg_frame_rate.split("/")[0]
+        fps_den = avg_frame_rate.split("/")[1]
+        return round(float(fps_num)/float(fps_den))
+
+    def _download_video_from_url(self, source_url, source_path:Path):
+        if source_path.exists(): source_path.unlink()
+
+        ydl_opts = {    
+            'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',     
+            'outtmpl': str(source_path)   
+            }
+        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+            ydl.download([source_url])
+
+    def _extract_raw_frames(self, source_path:Path):
+        bwframes_folder = self.bwframes_root/(source_path.stem)
+        bwframe_path_template = str(bwframes_folder/'%5d.jpg')
+        bwframes_folder.mkdir(parents=True, exist_ok=True)
+        self._purge_images(bwframes_folder)
+        ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)
+
+
+    def _colorize_raw_frames(self, source_path:Path):
+        colorframes_folder = self.colorframes_root/(source_path.stem)
+        colorframes_folder.mkdir(parents=True, exist_ok=True)
+        self._purge_images(colorframes_folder)
+        bwframes_folder = self.bwframes_root/(source_path.stem)
+
+        for img in progress_bar(os.listdir(str(bwframes_folder))):
+            img_path = bwframes_folder/img
+            if os.path.isfile(str(img_path)):
+                color_image = self.vis.get_transformed_image(str(img_path))
+                color_image.save(str(colorframes_folder/img))
+    
+    def _build_video(self, source_path:Path):
+        result_path = self.result_folder/source_path.name
+        colorframes_folder = self.colorframes_root/(source_path.stem)
+        colorframes_path_template = str(colorframes_folder/'%5d.jpg')
+        result_path.parent.mkdir(parents=True, exist_ok=True)
+        if result_path.exists(): result_path.unlink()
+        fps = self._get_fps(source_path)
+
+        ffmpeg.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=str(fps)) \
+            .output(str(result_path), crf=17, vcodec='libx264') \
+            .run(capture_stdout=True)
+        
+        print('Video created here: ' + str(result_path))
+
+    def colorize_from_url(self, source_url, file_name:str):    
+        source_path =  self.source_folder/file_name
+        self._download_video_from_url(source_url, source_path)
+        self._colorize_from_path(source_path)
+
+    def colorize_from_file_name(self, file_name:str):
+        source_path =  self.source_folder/file_name
+        self._colorize_from_path(source_path)
+
+    def _colorize_from_path(self, source_path:Path):
+        self._extract_raw_frames(source_path)
+        self._colorize_raw_frames(source_path)
+        self._build_video(source_path)
+
+
+def get_video_colorizer2(root_folder:Path=Path('./'), weights_name:str='ColorizeVideos_gen2', 
+        results_dir = 'result_images', render_factor:int=36)->VideoColorizer:
+    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name, arch=models.resnet101)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
+    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
+    return VideoColorizer(vis)
+
 
-def get_colorize_visualizer(root_folder:Path=Path('./'), weights_name:str='colorize_gen', 
-        results_dir = 'result_images', nf_factor:float=1.25, render_factor:int=21)->ModelImageVisualizer:
-    learn = colorize_gen_inference(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor)
+def get_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeVideos_gen', 
+        results_dir = 'result_images', render_factor:int=21, nf_factor:float=1.25)->VideoColorizer:
+    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
-    return vis
+    return VideoColorizer(vis)
 
-def get_colorize_visualizer2(root_folder:Path=Path('./'), weights_name:str='colorize_gen', 
-        results_dir = 'result_images', nf_factor:int=1, render_factor:int=21, arch=models.resnet34)->ModelImageVisualizer:
-    learn = colorize_gen_inference2(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor, arch=arch)
+def get_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeImages_gen', 
+        results_dir = 'result_images', render_factor:int=21, arch=models.resnet34)->ModelImageVisualizer:
+    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name, arch=arch)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
@@ -69,4 +160,3 @@ def get_colorize_visualizer2(root_folder:Path=Path('./'), weights_name:str='colo
 
 
 
-
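
End to end, the new entry points are used as below. Weights are read from <root_folder>/models/ as noted in generators.py, so this assumes weight files with the default names exist there; the URL and file name are placeholders, not values shipped with this commit:

    from fasterai.visualize import get_image_colorizer, get_video_colorizer

    # Still images: colorize a single file and plot it inline.
    vis = get_image_colorizer(render_factor=21)
    vis.plot_transformed_image('test_images/WaterfordIreland1909.jpg')

    # Video: download -> extract frames (ffmpeg) -> colorize frame by frame
    # -> re-encode at the source fps; the result lands under ./video/result/.
    colorizer = get_video_colorizer(render_factor=21)
    colorizer.colorize_from_url('https://youtu.be/EXAMPLE', 'example.mp4')  # placeholder URL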

Some files were not shown because too many files changed in this diff