Pārlūkot izejas kodu

Adding improved training regime; Refactored Unet and Filter code to accommodate multiple resnet implementations

Jason Antic 6 gadi atpakaļ
vecāks
revīzija
33343bc2f2
7 mainīti faili ar 392 papildinājumiem un 152 dzēšanām
  1. 1 0
      .gitignore
  2. 59 40
      ColorizeTraining.ipynb
  3. 148 27
      ColorizeVisualization.ipynb
  4. 64 43
      DeFadeTraining.ipynb
  5. 47 16
      fasterai/filters.py
  6. 68 21
      fasterai/generators.py
  7. 5 5
      fasterai/visualize.py

+ 1 - 0
.gitignore

@@ -365,3 +365,4 @@ result_images/Sami1880s.jpg
 result_images/Scotland1919.jpg
 result_images/SenecaNative1908.jpg
 result_images/TitanicGym.jpg
+.~ColorizeVisualization.ipynb

+ 59 - 40
ColorizeTraining.ipynb

@@ -49,18 +49,18 @@
     "proj_id = 'colorize'\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
-    "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
-    "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
+    "gpath = IMAGENET.parent/(proj_id + '_gen_128.h5')\n",
+    "dpath = IMAGENET.parent/(proj_id + '_critic_128.h5')\n",
     "\n",
-    "c_lr=2e-4\n",
+    "c_lr=5e-4\n",
     "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
     "\n",
-    "g_lr=c_lr/4\n",
-    "g_lrs = np.array([g_lr/1000,g_lr/100,g_lr])\n",
+    "g_lr=c_lr/5\n",
+    "g_lrs = np.array([g_lr/100,g_lr/10,g_lr])\n",
     "\n",
     "keep_pcts=[0.25,0.25]\n",
     "gen_freeze_tos=[-1,0]\n",
-    "lrs_unfreeze_factor=1.0\n",
+    "lrs_unfreeze_factor=0.05\n",
     "x_tfms = [BlackAndWhiteTransform()]\n",
     "extra_aug_tfms = [RandomLighting(0.1, 0.1, tfm_y=TfmType.PIXEL)]\n",
     "torch.backends.cudnn.benchmark=True"
@@ -83,7 +83,7 @@
     "#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')\n",
     "#load_model(netG, gpath)\n",
     "\n",
-    "netD = DCCritic(ni=3, nf=512).cuda()\n",
+    "netD = DCCritic(ni=3, nf=256).cuda()\n",
     "#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')\n",
     "#load_model(netD, dpath)"
    ]
@@ -106,64 +106,83 @@
    "source": [
     "scheds=[]\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=keep_pcts, \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[128,128], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=[1.0,1.0], \n",
     "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
+    "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/20, g_lrs=g_lrs/20, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[16,16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/2, g_lrs=g_lrs/2, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[64,64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
+    "\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "\n",
+    "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[4], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
    ]
   },
   {

+ 148 - 27
ColorizeVisualization.ipynb

@@ -36,9 +36,9 @@
     "from pathlib import Path\n",
     "from itertools import repeat\n",
     "import tensorboardX\n",
-    "torch.cuda.set_device(0)\n",
     "plt.style.use('dark_background')\n",
-    "torch.backends.cudnn.benchmark=True"
+    "torch.backends.cudnn.benchmark=True\n",
+    "torch.cuda.set_device(0)"
    ]
   },
   {
@@ -48,11 +48,14 @@
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "colorizer_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
+    "#colorizer_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
+    "colorizer_path = IMAGENET.parent/('colorize34_gen_192.h5')\n",
+    "#colorizer_path = IMAGENET.parent/('colorize18_gen_160.h5')\n",
+    "defader_path = IMAGENET.parent/('bwdefade3_gen_160.h5')\n",
     "\n",
     "#The higher the render_factor, the more GPU memory will be used and generally images will look better.  \n",
     "#11GB can take a factor of 42 max.  Performance generally gracefully degrades with lower factors, \n",
-    "#though you may also find that certiain images will actually render better at lower numbers.  \n",
+    "#though you may also find that certain images will actually render better at lower numbers.  \n",
     "#This tends to be the case with the oldest photos.\n",
     "render_factor=42"
    ]
@@ -63,7 +66,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "filters = [Colorizer(gpu=0, weights_path=colorizer_path)]\n",
+    "filters = [Colorizer34(gpu=0, weights_path=colorizer_path,nf_factor=2, map_to_orig=True)]\n",
+    "#filters = [DeFader(gpu=3, weights_path=defader_path, nf_factor=2), Colorizer34(gpu=0, weights_path=colorizer_path,nf_factor=2, map_to_orig=True)]\n",
     "vis = ModelImageVisualizer(filters, render_factor=render_factor, results_dir='result_images')"
    ]
   },
@@ -73,7 +77,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Chief.jpg\")"
+    "vis.plot_transformed_image(\"test_images/1852GatekeepersWindsor.jpg\")"
    ]
   },
   {
@@ -82,7 +86,124 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\")"
+    "vis.plot_transformed_image(\"test_images/Chief.jpg\", render_factor=17)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/1850SchoolForGirls.jpg\", render_factor=42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/AtlanticCityBeach1905.jpg\", render_factor=42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/CottonMillWorkers1913.jpg\", render_factor=41)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/BrooklynNavyYardHospital.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/FinnishPeasant1867.jpg\", render_factor=20)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/AtlanticCity1905.png\", render_factor=39)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/PushingCart.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/Drive1905.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/IronLung.png\", render_factor=39)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/FamilyWithDog.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/DayAtSeaBelgium.jpg\", render_factor=41)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\", render_factor=30)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/OldWomanSweden1904.jpg\", render_factor=39)"
    ]
   },
   {
@@ -154,7 +275,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\")"
+    "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\", render_factor=42)"
    ]
   },
   {
@@ -262,7 +383,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\")"
+    "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\", render_factor=40)"
    ]
   },
   {
@@ -352,7 +473,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/unnamed.jpg\")"
+    "vis.plot_transformed_image(\"test_images/unnamed.jpg\", render_factor=40)"
    ]
   },
   {
@@ -487,7 +608,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/women-bikers.png\", figsize=(60,60))"
+    "vis.plot_transformed_image(\"test_images/women-bikers.png\", figsize=(60,60), render_factor=42)"
    ]
   },
   {
@@ -550,7 +671,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/poverty.jpg\")"
+    "vis.plot_transformed_image(\"test_images/poverty.jpg\", render_factor=40)"
    ]
   },
   {
@@ -595,7 +716,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\", render_factor=35)"
+    "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\", render_factor=31)"
    ]
   },
   {
@@ -721,7 +842,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/bicycles.jpg\",render_factor=17)"
+    "vis.plot_transformed_image(\"test_images/bicycles.jpg\")"
    ]
   },
   {
@@ -982,7 +1103,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1860Girls.jpg\")"
+    "vis.plot_transformed_image(\"test_images/1860Girls.jpg\", render_factor=41)"
    ]
   },
   {
@@ -1072,7 +1193,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\")"
+    "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\", render_factor=17)"
    ]
   },
   {
@@ -1099,7 +1220,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/ArkansasCowboys1880s.jpg\")"
+    "vis.plot_transformed_image(\"test_images/ArkansasCowboys1880s.jpg\", render_factor=24)"
    ]
   },
   {
@@ -1108,7 +1229,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\")"
+    "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\", render_factor=34)"
    ]
   },
   {
@@ -1144,7 +1265,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\")"
+    "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\", render_factor=41)"
    ]
   },
   {
@@ -1207,7 +1328,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\")"
+    "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\", render_factor=40)"
    ]
   },
   {
@@ -1243,7 +1364,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\")"
+    "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\", render_factor=38)"
    ]
   },
   {
@@ -1468,7 +1589,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\")"
+    "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\", render_factor=38)"
    ]
   },
   {
@@ -1558,7 +1679,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\")"
+    "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\", render_factor=40)"
    ]
   },
   {
@@ -1666,7 +1787,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\")"
+    "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\", render_factor=33)"
    ]
   },
   {
@@ -1936,7 +2057,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\")"
+    "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\", render_factor=40)"
    ]
   },
   {
@@ -2404,7 +2525,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\")"
+    "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\", render_factor=42)"
    ]
   },
   {
@@ -2566,7 +2687,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\")"
+    "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\", render_factor=43)"
    ]
   },
   {

+ 64 - 43
DeFadeTraining.ipynb

@@ -34,7 +34,7 @@
     "from pathlib import Path\n",
     "from itertools import repeat\n",
     "import tensorboardX\n",
-    "torch.cuda.set_device(3)\n",
+    "torch.cuda.set_device(0)\n",
     "plt.style.use('dark_background')\n",
     "torch.backends.cudnn.benchmark=True\n"
    ]
@@ -46,19 +46,21 @@
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "proj_id = 'defade'\n",
+    "proj_id = 'bwdefade'\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
-    "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
-    "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
-    "c_lr=2e-4\n",
+    "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
+    "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
+    "c_lr=5e-4\n",
     "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
-    "g_lr=c_lr/4\n",
-    "g_lrs = np.array([g_lr/1000,g_lr/100,g_lr])\n",
+    "\n",
+    "g_lr=c_lr/5\n",
+    "g_lrs = np.array([g_lr/100,g_lr/10,g_lr])\n",
+    "\n",
     "keep_pcts=[0.25,0.25]\n",
     "gen_freeze_tos=[-1,0]\n",
-    "lrs_unfreeze_factor=1.0\n",
+    "lrs_unfreeze_factor=0.05\n",
     "x_tfms = [RandomLighting(0.5, 0.5)]\n",
-    "extra_aug_tfms = []\n",
+    "extra_aug_tfms = [BlackAndWhiteTransform(tfm_y=TfmType.PIXEL)]\n",
     "torch.backends.cudnn.benchmark=True"
    ]
   },
@@ -79,7 +81,7 @@
     "#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')\n",
     "#load_model(netG, gpath)\n",
     "\n",
-    "netD = DCCritic(ni=3, nf=512).cuda()\n",
+    "netD = DCCritic(ni=3, nf=384).cuda()\n",
     "#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')\n",
     "#load_model(netD, dpath)"
    ]
@@ -102,64 +104,83 @@
    "source": [
     "scheds=[]\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=keep_pcts, \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[128,128], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=[1.0,1.0], \n",
     "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
+    "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/20, g_lrs=g_lrs/20, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[16,16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/2, g_lrs=g_lrs/2, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[64,64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
+    "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[4], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
     "\n",
+    "c_lrs=c_lrs/2\n",
+    "g_lrs=g_lrs/2\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n"
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
+    "    save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
    ]
   },
   {

+ 47 - 16
fasterai/filters.py

@@ -1,6 +1,6 @@
 from numpy import ndarray
 from abc import ABC, abstractmethod
-from .generators import Unet34, GeneratorModule
+from .generators import Unet34, Unet101, GeneratorModule
 from .transforms import BlackAndWhiteTransform
 from fastai.torch_imports import *
 from fastai.core import *
@@ -23,7 +23,7 @@ class Filter(ABC):
         self.denorm = Denormalize(*inception_stats)
     
     @abstractmethod
-    def filter(self, orig_image:ndarray, render_factor:int)->ndarray:
+    def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int)->ndarray:
         pass
 
     def _init_model(self, model:nn.Module, weights_path:Path):
@@ -58,9 +58,10 @@ class Filter(ABC):
             image = image[None]
         return self.denorm(np.rollaxis(image,1,4))
 
-    def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int):
+    def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int, gpu:int):
         orig = self._get_model_ready_image_ndarray(orig, sz)
         orig = VV_(orig[None]) 
+        orig = orig.to(device=gpu)
         result = model(orig)
         result = result.detach().cpu().numpy()
         result = self._denorm(result)
@@ -75,17 +76,27 @@ class Filter(ABC):
         return cv2.resize(result, sz, interpolation=cv2.INTER_AREA)  
 
 
-class Colorizer(Filter):
-    def __init__(self, gpu:int, weights_path:Path):
+
+class AbstractColorizer(Filter):
+    def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
         super().__init__(tfms=[BlackAndWhiteTransform()])
-        self.model = Unet34(nf_factor=2).cuda(gpu)
+        self.model = self._get_model(nf_factor=nf_factor, gpu=gpu)
+        self.gpu = gpu
         self._init_model(self.model, weights_path)
         self.render_base=16
+        self.map_to_orig=map_to_orig
+
+    @abstractmethod
+    def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
+        pass
     
-    def filter(self, orig_image:ndarray, render_factor:int=36)->ndarray:
+    def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int=36)->ndarray:
         render_sz = render_factor * self.render_base
-        model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
-        return self._post_process(model_image, orig_image)
+        model_image = self._model_process(self.model, orig=filtered_image, sz=render_sz, gpu=self.gpu)
+        if self.map_to_orig:
+            return self._post_process(model_image, orig_image)
+        else:
+            return self._post_process(model_image, filtered_image)
 
 
     #This takes advantage of the fact that human eyes are much less sensitive to 
@@ -105,16 +116,36 @@ class Colorizer(Filter):
         hires[:,:,1:3] = color_yuv[:,:,1:3]
         return cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)   
 
+class Colorizer34(AbstractColorizer):
+    def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
+        super().__init__(gpu=gpu, weights_path=weights_path, nf_factor=nf_factor, map_to_orig=map_to_orig)
+
+    def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
+        return Unet34(nf_factor=nf_factor).cuda(gpu)
+
+
+class Colorizer101(AbstractColorizer):
+    def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
+        super().__init__(gpu=gpu, weights_path=weights_path, nf_factor=nf_factor, map_to_orig=map_to_orig)
+
+    def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
+        return Unet101(nf_factor=nf_factor).cuda(gpu)
+
+
 #TODO:  May not want to do square rendering here like in colorization -- it visibly loses
 #fidelity (though not too terribly).  Will revisit.
 class DeFader(Filter): 
-    def __init__(self, gpu:int, weights_path:Path):
-        super().__init__(tfms=[])
-        self.model = Unet34(nf_factor=2).cuda(gpu)
+    def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2):
+        super().__init__(tfms=[BlackAndWhiteTransform()])
+        self.model = Unet34(nf_factor=nf_factor).cuda(gpu)
         self._init_model(self.model, weights_path)
-        self.render_base=16    
+        self.render_base=16
+        self.gpu = gpu
 
-    def filter(self, orig_image:ndarray, render_factor:int=36)->ndarray:
+    def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int=36)->ndarray:
         render_sz = render_factor * self.render_base
-        model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
-        return self._unsquare(model_image, orig_image)
+        model_image = self._model_process(self.model, orig=filtered_image, sz=render_sz, gpu=self.gpu)
+        return self._post_process(model_image, filtered_image)
+
+    def _post_process(self, result:ndarray, orig:ndarray):
+        return self._unsquare(result, orig)

+ 68 - 21
fasterai/generators.py

@@ -4,6 +4,7 @@ from fastai.transforms import scale_min
 from .modules import ConvBlock, UnetBlock, UpSampleBlock, SaveFeatures
 from abc import ABC, abstractmethod
 from torchvision import transforms
+from torch.nn.utils.spectral_norm import spectral_norm
 
 class GeneratorModule(ABC, nn.Module):
     def __init__(self):
@@ -29,30 +30,27 @@ class GeneratorModule(ABC, nn.Module):
         return next(self.parameters()).device
 
 
-class Unet34(GeneratorModule): 
-    @staticmethod
-    def _get_pretrained_resnet_base(layers_cut:int=0):
-        f = resnet34
-        cut,lr_cut = model_meta[f]
-        cut-=layers_cut
-        layers = cut_model(f(True), cut)
-        return nn.Sequential(*layers), lr_cut
-
+class AbstractUnet(GeneratorModule): 
     def __init__(self, nf_factor:int=1, scale:int=1):
         super().__init__()
         assert (math.log(scale,2)).is_integer()
-        leakyReLu=False
-        self_attention=True
-        bn=True
-        sn=True
-        self.rn, self.lr_cut = Unet34._get_pretrained_resnet_base()
+        self.rn, self.lr_cut = self._get_pretrained_resnet_base()
+        ups = self._get_decoding_layers(nf_factor=nf_factor, scale=scale)
         self.relu = nn.ReLU()
-        self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up3 = UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
-        self.up4 = UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
-        self.up5 = UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn) 
-        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
+        self.up1 = ups[0]
+        self.up2 = ups[1]
+        self.up3 = ups[2]
+        self.up4 = ups[3]
+        self.up5 = ups[4]
+        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=True), nn.Tanh())
+
+    @abstractmethod
+    def _get_pretrained_resnet_base(self, layers_cut:int=0):
+        pass
+
+    @abstractmethod
+    def _get_decoding_layers(self, nf_factor:int, scale:int):
+        pass
 
     #Gets around irritating inconsistent halving coming from resnet
     def _pad(self, x:torch.Tensor, target:torch.Tensor, total_padh:int, total_padw:int)-> torch.Tensor:
@@ -128,4 +126,53 @@ class Unet34(GeneratorModule):
         for sf in self.sfs: 
             sf.remove()
 
- 
+
+class Unet34(AbstractUnet): 
+    def __init__(self, nf_factor:int=1, scale:int=1):
+        super().__init__(nf_factor=nf_factor, scale=scale)
+
+    def _get_pretrained_resnet_base(self, layers_cut:int=0):
+        f = resnet34
+        cut,lr_cut = model_meta[f]
+        cut-=layers_cut
+        layers = cut_model(f(True), cut)
+        return nn.Sequential(*layers), lr_cut
+
+    def _get_decoding_layers(self, nf_factor:int, scale:int):
+        self_attention=True
+        bn=True
+        sn=True
+        leakyReLu=False
+        layers = []
+        layers.append(UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        return layers 
+
+
+class Unet101(AbstractUnet): 
+    def __init__(self, nf_factor:int=1, scale:int=1):
+        super().__init__(nf_factor=nf_factor, scale=scale)
+
+    def _get_pretrained_resnet_base(self, layers_cut:int=0):
+        f = resnet101
+        cut,lr_cut = model_meta[f]
+        cut-=layers_cut
+        layers = cut_model(f(True), cut)
+        return nn.Sequential(*layers), lr_cut
+
+    def _get_decoding_layers(self, nf_factor:int, scale:int):
+        self_attention=True
+        bn=True
+        sn=True
+        leakyReLu=False
+        layers = []
+        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
+        return layers 
+

+ 5 - 5
fasterai/visualize.py

@@ -11,7 +11,7 @@ from fasterai.transforms import BlackAndWhiteTransform
 from .training import GenResult, CriticResult, GANTrainer
 from .images import ModelImageSet, EasyTensorImage
 from .generators import GeneratorModule
-from .filters import Filter, Colorizer
+from .filters import Filter
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
@@ -46,14 +46,14 @@ class ModelImageVisualizer():
         misc.imsave(result_path, np.clip(result,0,1))
 
     def _get_transformed_image_ndarray(self, path:Path, render_factor:int=None):
-        orig = open_image(str(path))
-        result = orig
+        orig_image = open_image(str(path))
+        filtered_image = orig_image
         render_factor = self.render_factor if render_factor is None else render_factor
 
         for filt in self.filters:
-            result = filt.filter(result, render_factor=render_factor)
+            filtered_image = filt.filter(orig_image, filtered_image, render_factor=render_factor)
 
-        return result
+        return filtered_image
 
     def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None: