ソースを参照

More progress on FastAI v1 upgrade

Jason Antic 6 年 前
コミット
2a52040220
8 ファイル変更988 行追加93 行削除
  1. 5 0
      .gitignore
  2. 25 29
      ColorizeTraining.ipynb
  3. 740 0
      SuperResTraining.ipynb
  4. 4 2
      fasterai/critics.py
  5. 37 1
      fasterai/generators.py
  6. 32 0
      fasterai/loss.py
  7. 54 57
      fasterai/tensorboard.py
  8. 91 4
      fasterai/unet.py

+ 5 - 0
.gitignore

@@ -488,3 +488,8 @@ fastai
 .ipynb_checkpoints/SuperResTraining-checkpoint.ipynb
 .ipynb_checkpoints/ColorizeTrainingNew2-checkpoint.ipynb
 .ipynb_checkpoints/ColorizeTrainingNew-checkpoint.ipynb
+.ipynb_checkpoints/ColorizeTrainingNew3-checkpoint.ipynb
+.ipynb_checkpoints/ColorizeTrainingNew4-checkpoint.ipynb
+ColorizeTrainingNew2.ipynb
+ColorizeTrainingNew3.ipynb
+ColorizeTrainingNew4.ipynb

+ 25 - 29
ColorizeTraining.ipynb

@@ -6,9 +6,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "%reload_ext autoreload\n",
-    "%autoreload 2"
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
    ]
   },
   {
@@ -30,9 +29,9 @@
     "from fasterai.generators import *\n",
     "from pathlib import Path\n",
     "from itertools import repeat\n",
-    "torch.cuda.set_device(2)\n",
     "plt.style.use('dark_background')\n",
-    "torch.backends.cudnn.benchmark=True\n"
+    "torch.backends.cudnn.benchmark=True\n",
+    "from PIL import ImageFile"
    ]
   },
   {
@@ -44,7 +43,7 @@
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
     "BWIMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
     "\n",
-    "proj_id = 'colorizeV5o'\n",
+    "proj_id = 'colorizeESR45'\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
     "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
@@ -132,8 +131,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
-    "    return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
+    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss4(), arch=models.resnet34):\n",
+    "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
     "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
    ]
   },
@@ -152,24 +151,24 @@
    "source": [
     "#Needed to instantiate critic but not actually used\n",
     "sz=64\n",
-    "bs=32\n",
+    "bs=128\n",
     "\n",
     "data = get_data(sz=sz, bs=bs, keep_pct=1.0)\n",
     "learn_crit = colorize_crit_learner(data=data, nf=256)\n",
     "learn_crit.unfreeze()\n",
     "\n",
-    "gen_loss = FeatureLoss2(gram_wgt=5e3)\n",
+    "gen_loss = FeatureLoss4()\n",
     "learn_gen = colorize_gen_learner_exp(data=data)\n",
     "\n",
     "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
-    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.0), show_img=False, switcher=switcher,\n",
-    "                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=1e-3)\n",
+    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
+    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
     "\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
     "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
     "\n",
     "lr=1e-4\n",
-    "unfreeze_fctr=0.05"
+    "unfreeze_fctr=0.1"
    ]
   },
   {
@@ -230,10 +229,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "load()\n",
-    "lr=lr/2\n",
+    "#lr=lr/2\n",
     "sz=96\n",
-    "#bs=bs//2"
+    "bs=bs//2"
    ]
   },
   {
@@ -262,7 +260,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=1.0)"
    ]
   },
   {
@@ -316,7 +314,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "lr=lr/2\n",
+    "#lr=lr/2\n",
     "sz=128\n",
     "bs=bs//2"
    ]
@@ -347,7 +345,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=1.0)"
    ]
   },
   {
@@ -406,15 +404,6 @@
     "bs=int(bs//1.5)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=10"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -582,7 +571,7 @@
    "source": [
     "lr=lr/1.5\n",
     "sz=224\n",
-    "bs=int(bs//1.5)"
+    "bs=bs//1.5"
    ]
   },
   {
@@ -652,6 +641,13 @@
     "save()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": null,

+ 740 - 0
SuperResTraining.ipynb

@@ -0,0 +1,740 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pretrained GAN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='2' "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import fastai\n",
+    "from fastai import *\n",
+    "from fastai.vision import *\n",
+    "from fastai.callbacks import *\n",
+    "from fastai.vision.gan import *\n",
+    "from fasterai.generators import *\n",
+    "from fasterai.tensorboard import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "path = untar_data(URLs.PETS)\n",
+    "path_hr = path/'images'\n",
+    "path_lr = path/'crappy'\n",
+    "\n",
+    "proj_id = 'SuperResRefine5c'\n",
+    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Crappified data"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Prepare the input data by crappifying images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from PIL import Image, ImageDraw, ImageFont"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def crappify(fn,i):\n",
+    "    dest = path_lr/fn.relative_to(path_hr)\n",
+    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
+    "    img = PIL.Image.open(fn)\n",
+    "    targ_sz = resize_to(img, 96, use_min=True)\n",
+    "    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')\n",
+    "    w,h = img.size\n",
+    "    q = random.randint(10,70)\n",
+    "    ImageDraw.Draw(img).text((random.randint(0,w//2),random.randint(0,h//2)), str(q), fill=(255,255,255))\n",
+    "    img.save(dest, quality=q)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Uncomment the first time you run this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#il = ImageItemList.from_folder(path_hr)\n",
+    "#parallel(crappify, il.items)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For gradual resizing we can change the commented line here."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs,size=32, 128\n",
+    "# bs,size = 24,160\n",
+    "#bs,size = 8,256\n",
+    "arch = models.resnet34"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pre-train generator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now let's pretrain the generator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "arch = models.resnet34\n",
+    "src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_data(bs,size):\n",
+    "    data = (src.label_from_func(lambda x: path_hr/x.name)\n",
+    "           .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)\n",
+    "           .databunch(bs=bs).normalize(imagenet_stats, do_y=True))\n",
+    "\n",
+    "    data.c = 3\n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_gen = get_data(bs,size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wd = 1e-3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_range = (-3.,3.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_gen = FeatureLoss(gram_wgt=5e3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_gen_learner():\n",
+    "    return unet_learner2(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Spectral,\n",
+    "                         self_attention=True, y_range=y_range, loss_func=loss_gen)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = create_gen_learner()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(8, pct_start=0.8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(12, slice(1e-6,1e-3))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.show_results(rows=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save('gen-pre-c')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Save generated images"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.load('gen-pre-c');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "name_gen = 'image_gen-c'\n",
+    "path_gen = path/name_gen"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# shutil.rmtree(path_gen)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "path_gen.mkdir(exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_preds(dl):\n",
+    "    i=0\n",
+    "    names = dl.dataset.items\n",
+    "    \n",
+    "    for b in dl:\n",
+    "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
+    "        for o in preds:\n",
+    "            o.save(path_gen/names[i].name)\n",
+    "            i += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save_preds(data_gen.fix_dl)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PIL.Image.open(path_gen.ls()[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Train critic"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretrain the critic on crappy vs not crappy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_crit_data(classes, bs, size):\n",
+    "    src = ImageItemList.from_folder(path, include=classes).random_split_by_pct(0.1, seed=42)\n",
+    "    ll = src.label_from_folder(classes=classes)\n",
+    "    data = (ll.transform(get_transforms(max_zoom=2.), size=size)\n",
+    "           .databunch(bs=bs).normalize(imagenet_stats))\n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_critic_learner(data, metrics):\n",
+    "    learner = Learner(data, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)\n",
+    "    return learner"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic = create_critic_learner(data_crit, accuracy_thresh_expand)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.fit_one_cycle(6, 1e-3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_critic.save('critic-pre-c')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## GAN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we'll combine those pretrained models in a GAN."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit=None\n",
+    "learn_gen=None\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen = create_gen_learner().load('gen-pre-c')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss back < 0.65 (the `critic_thresh` passed to the switcher below), then one iteration of the generator.\n",
+    "\n",
+    "The loss of the critic is given by `learn_crit.loss_func`. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0). \n",
+    "\n",
+    "The loss of the generator is a weighted sum (weights in `weights_gen`) of `learn_crit.loss_func` on the batch of fake (passed through the critic to become predictions) with a target of 1, and the `learn_gen.loss_func` applied to the output (batch of fake) and the target (corresponding batch of superres images)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
+    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
+    "                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)\n",
+    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
+    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr = 1e-4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.fit(10,lr)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.show_results()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.save('gan-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.load('gan-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data=get_data(14,192)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.fit(10,lr/2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=14)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.save('gan-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data=get_data(7,256)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.fit(10,lr/4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.save('gan-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=7)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.fit(20,lr/40)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.save('gan-c')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.load('gan-c');"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data=get_data(16,256)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=14)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data=get_data(32,192)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.show_results(rows=32)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## fin"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 4 - 2
fasterai/critics.py

@@ -1,7 +1,7 @@
 from fastai.core import *
 from fastai.torch_core import *
 from fastai.vision import *
-from fastai.vision.gan import AdaptiveLoss
+from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
 
 _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
 
@@ -16,13 +16,15 @@ def gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
         nn.Dropout2d(p/2)]
     for i in range(n_blocks):
         layers += [
+            _conv(nf, nf, ks=3, stride=1),
             nn.Dropout2d(p),
             _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
         nf *= 2
     layers += [
+        _conv(nf, nf, ks=3, stride=1),
         _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
         Flatten()]
     return nn.Sequential(*layers)
 
 def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
-    return Learner(data, gan_critic2(nf=nf), metrics=None, loss_func=loss_critic, wd=1e-3)
+    return Learner(data, gan_critic2(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)

+ 37 - 1
fasterai/generators.py

@@ -1,6 +1,6 @@
 from fastai.vision import *
 from fastai.vision.learner import cnn_config
-from fasterai.unet import *
+from fasterai.unet import DynamicUnet2, DynamicUnet3, DynamicUnet4
 from .loss import FeatureLoss
 
 def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):
@@ -24,3 +24,39 @@ def unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_fina
     if pretrained: learn.freeze()
     apply_init(model[2], nn.init.kaiming_normal_)
     return learn
+
+
+def unet_learner3(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
+                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
+                 bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
+    "Build Unet learner from `data` and `arch`."
+    meta = cnn_config(arch)
+    body = create_body(arch, pretrained)
+    model = to_device(DynamicUnet3(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
+          bottle=bottle, nf_factor=nf_factor), data.device)
+    learn = Learner(data, model, **kwargs)
+    learn.split(ifnone(split_on,meta['split']))
+    if pretrained: learn.freeze()
+    apply_init(model[2], nn.init.kaiming_normal_)
+    return learn
+
+
+#No batch norm in ESRGAN paper
+def unet_learner4(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
+                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
+                 bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
+    "Build Unet learner from `data` and `arch`."
+    meta = cnn_config(arch)
+    body = create_body(arch, pretrained)
+    model = to_device(DynamicUnet4(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
+          bottle=bottle, nf_factor=nf_factor), data.device)
+    learn = Learner(data, model, **kwargs)
+    learn.split(ifnone(split_on,meta['split']))
+    if pretrained: learn.freeze()
+    apply_init(model[2], nn.init.kaiming_normal_)
+    return learn
+

+ 32 - 0
fasterai/loss.py

@@ -156,4 +156,36 @@ class FeatureLoss3(nn.Module):
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
     
+    def __del__(self): self.hooks.remove()
+
+
+#"Before activations" in ESRGAN paper
+class FeatureLoss4(nn.Module):
+    def __init__(self, layer_wgts=[5,15,2]):
+        super().__init__()
+
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-2 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] 
+        self.base_loss = F.l1_loss
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+    
     def __del__(self): self.hooks.remove()

+ 54 - 57
fasterai/tensorboard.py

@@ -27,8 +27,11 @@ class ModelHistogramVisualizer():
         return 
 
     def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
-        for param_name, param in model.named_parameters():
-            tbwriter.add_histogram(name + '/weights/' + param_name, param, iter_count)
+        try:
+            for param_name, param in model.named_parameters():
+                tbwriter.add_histogram(name + '/weights/' + param_name, param, iter_count)
+        except Exception as e:
+            print(("Failed to update histogram for model:  {0}").format(e))
 
 
 class ModelStatsVisualizer(): 
@@ -36,39 +39,42 @@ class ModelStatsVisualizer():
         return 
 
     def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
-        gradients = [x.grad  for x in model.parameters() if x.grad is not None]
-        gradient_nps = [to_np(x.data) for x in gradients]
- 
-        if len(gradients) == 0:
-            return 
+        try:
+            gradients = [x.grad  for x in model.parameters() if x.grad is not None]
+            gradient_nps = [to_np(x.data) for x in gradients]
+    
+            if len(gradients) == 0:
+                return 
 
-        avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
-        tbwriter.add_scalar(name + '/gradients/avg_norm', avg_norm, iter_count)
+            avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
+            tbwriter.add_scalar(name + '/gradients/avg_norm', avg_norm, iter_count)
 
-        median_norm = statistics.median(x.data.norm() for x in gradients)
-        tbwriter.add_scalar(name + '/gradients/median_norm', median_norm, iter_count)
+            median_norm = statistics.median(x.data.norm() for x in gradients)
+            tbwriter.add_scalar(name + '/gradients/median_norm', median_norm, iter_count)
 
-        max_norm = max(x.data.norm() for x in gradients)
-        tbwriter.add_scalar(name + '/gradients/max_norm', max_norm, iter_count)
+            max_norm = max(x.data.norm() for x in gradients)
+            tbwriter.add_scalar(name + '/gradients/max_norm', max_norm, iter_count)
 
-        min_norm = min(x.data.norm() for x in gradients)
-        tbwriter.add_scalar(name + '/gradients/min_norm', min_norm, iter_count)
+            min_norm = min(x.data.norm() for x in gradients)
+            tbwriter.add_scalar(name + '/gradients/min_norm', min_norm, iter_count)
 
-        num_zeros = sum((np.asarray(x)==0.0).sum() for x in  gradient_nps)
-        tbwriter.add_scalar(name + '/gradients/num_zeros', num_zeros, iter_count)
+            num_zeros = sum((np.asarray(x)==0.0).sum() for x in  gradient_nps)
+            tbwriter.add_scalar(name + '/gradients/num_zeros', num_zeros, iter_count)
 
 
-        avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
-        tbwriter.add_scalar(name + '/gradients/avg_gradient', avg_gradient, iter_count)
+            avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
+            tbwriter.add_scalar(name + '/gradients/avg_gradient', avg_gradient, iter_count)
 
-        median_gradient = statistics.median(x.data.median() for x in gradients)
-        tbwriter.add_scalar(name + '/gradients/median_gradient', median_gradient, iter_count)
+            median_gradient = statistics.median(x.data.median() for x in gradients)
+            tbwriter.add_scalar(name + '/gradients/median_gradient', median_gradient, iter_count)
 
-        max_gradient = max(x.data.max() for x in gradients) 
-        tbwriter.add_scalar(name + '/gradients/max_gradient', max_gradient, iter_count)
+            max_gradient = max(x.data.max() for x in gradients) 
+            tbwriter.add_scalar(name + '/gradients/max_gradient', max_gradient, iter_count)
 
-        min_gradient = min(x.data.min() for x in gradients) 
-        tbwriter.add_scalar(name + '/gradients/min_gradient', min_gradient, iter_count)
+            min_gradient = min(x.data.min() for x in gradients) 
+            tbwriter.add_scalar(name + '/gradients/min_gradient', min_gradient, iter_count)
+        except Exception as e:
+            print(("Failed to update tensorboard stats for model:  {0}").format(e))
 
 class ImageGenVisualizer():
     def output_image_gen_visuals(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iter_count:int, tbwriter:SummaryWriter):
@@ -80,20 +86,23 @@ class ImageGenVisualizer():
         self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, ds_type=ds_type)
     
     def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, ds_type: DatasetType):
-        orig_images = []
-        gen_images = []
-        real_images = []
+        try:
+            orig_images = []
+            gen_images = []
+            real_images = []
 
-        for image_set in image_sets:
-            orig_images.append(image_set.orig.px)
-            gen_images.append(image_set.gen.px)
-            real_images.append(image_set.real.px)
+            for image_set in image_sets:
+                orig_images.append(image_set.orig.px)
+                gen_images.append(image_set.gen.px)
+                real_images.append(image_set.real.px)
 
-        prefix = str(ds_type)
+            prefix = str(ds_type)
 
-        tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iter_count)
-        tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iter_count)
-        tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
+            tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iter_count)
+            tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iter_count)
+            tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
+        except Exception as e:
+            print(("Failed to update tensorboard images for model:  {0}").format(e))
 
 
 #--------Below are what you actually want to use, in practice----------------#
@@ -110,22 +119,18 @@ class LearnerTensorboardWriter(LearnerCallback):
         self.loss_iters = loss_iters
         self.weight_iters = weight_iters
         self.stats_iters = stats_iters
-        self.iter_count = 0
         self.weight_vis = ModelHistogramVisualizer()
         self.model_vis = ModelStatsVisualizer() 
         self.data = None
-        #Keeping track of iterations in callback, because callback can be used for multiple epocs and multiple fit calls.
-        #This ensures that graphs show continuous iterations rather than resetting to 0 (which makes them much harder to read!)
-        self.iteration = -1
 
     def _update_batches_if_needed(self):
-        #one_batch is extremely slow.  this is an optimization
+        #one_batch is extremely slow, so the train/valid batches are cached and only refetched when learn.data changes
         update_batches = self.data is not self.learn.data
 
         if update_batches:
             self.data = self.learn.data
-            self.trn_batch = self.learn.data.one_batch(DatasetType.Train, detach=False, denorm=False)
-            self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=False, denorm=False)
+            self.trn_batch = self.learn.data.one_batch(DatasetType.Train, detach=True, denorm=False, cpu=False)
+            self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=True, denorm=False, cpu=False)
 
     def _write_model_stats(self, iteration):
         self.model_vis.write_tensorboard_stats(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter) 
@@ -153,10 +158,7 @@ class LearnerTensorboardWriter(LearnerCallback):
             if value is None: continue
             self.tbwriter.add_scalar('/metrics/' + name, to_np(value), iteration) 
 
-    def on_batch_end(self, last_loss, metrics, **kwargs):
-        self.iteration +=1
-        iteration = self.iteration
-
+    def on_batch_end(self, last_loss, metrics, iteration, **kwargs):
         if iteration==0:
             return
 
@@ -171,8 +173,7 @@ class LearnerTensorboardWriter(LearnerCallback):
         if iteration % self.stats_iters == 0:
             self._write_model_stats(iteration)
 
-    def on_epoch_end(self, metrics, last_metrics, **kwargs):
-        iteration = self.iteration  
+    def on_epoch_end(self, metrics, last_metrics, iteration, **kwargs):
         self._write_val_loss(iteration, last_metrics)
         self._write_metrics(iteration)
 
@@ -232,10 +233,8 @@ class GANTensorboardWriter(LearnerTensorboardWriter):
                                                 iter_count=iteration, tbwriter=self.tbwriter)
         trainer.switch(gen_mode=gen_mode)
 
-    def on_batch_end(self, metrics, **kwargs):
-        super().on_batch_end(metrics=metrics, **kwargs)
-
-        iteration = self.iteration
+    def on_batch_end(self, metrics, iteration, **kwargs):
+        super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
 
         if iteration==0:
             return
@@ -257,10 +256,8 @@ class ImageGenTensorboardWriter(LearnerTensorboardWriter):
         self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
             iter_count=iteration, tbwriter=self.tbwriter)
 
-    def on_batch_end(self, metrics, **kwargs):
-        super().on_batch_end(metrics=metrics, **kwargs)
-
-        iteration = self.iteration
+    def on_batch_end(self, metrics, iteration, **kwargs):
+        super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
 
         if iteration==0:
             return

+ 91 - 4
fasterai/unet.py

@@ -36,13 +36,13 @@ class PixelShuffle_ICNR2(nn.Module):
 class UnetBlock2(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
     def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
-                 self_attention:bool=False,  **kwargs):
+                 self_attention:bool=False, nf_factor:float=1.0,  **kwargs):
         super().__init__()
         self.hook = hook
         self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
         self.bn = batchnorm_2d(x_in_c)
         ni = up_in_c//2 + x_in_c
-        nf = ni if final_div else ni//2
+        nf = int((ni if final_div else ni//2)*nf_factor)
         self.conv1 = conv_layer2(ni, nf, leaky=leaky, **kwargs)
         self.conv2 = conv_layer2(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
         self.relu = relu(leaky=leaky)
@@ -61,7 +61,7 @@ class DynamicUnet2(SequentialEx):
     "Create a U-Net from a given architecture."
     def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                  y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
-                 norm_type:Optional[NormType]=NormType.Batch, **kwargs):
+                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
         #extra_bn =  norm_type in (NormType.Spectral, NormType.Weight)
         extra_bn =  norm_type == NormType.Spectral
         imsize = (256,256)
@@ -82,7 +82,7 @@ class DynamicUnet2(SequentialEx):
             do_blur = blur and (not_final or blur_final)
             sa = self_attention and (i==len(sfs_idxs)-3)
             unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
+                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
             layers.append(unet_block)
             x = unet_block(x)
 
@@ -91,10 +91,97 @@ class DynamicUnet2(SequentialEx):
         if last_cross:
             layers.append(MergeLayer(dense=True))
             ni += in_channels(encoder)
+            #TODO:  The res_block call below is missing the norm_type argument (DynamicUnet3 and DynamicUnet4 pass it).
             layers.append(res_block(ni, bottle=bottle, **kwargs))
         layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
         if y_range is not None: layers.append(SigmoidRange(*y_range))
         super().__init__(*layers)
 
+    def __del__(self):
+        if hasattr(self, "sfs"): self.sfs.remove()
+
+
+class DynamicUnet3(SequentialEx):
+    "Create a U-Net from a given architecture."
+    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
+                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
+                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
+        extra_bn =  norm_type == NormType.Spectral
+        imsize = (256,256)
+        sfs_szs = model_sizes(encoder, size=imsize)
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        x = dummy_eval(encoder, imsize).detach()
+
+        ni = sfs_szs[-1][1]
+        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        x = middle_conv(x)
+        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
+
+        for i,idx in enumerate(sfs_idxs):
+            not_final = i!=len(sfs_idxs)-1
+            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            sa = self_attention and (i==len(sfs_idxs)-3)
+            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
+            layers.append(unet_block)
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            layers.append(MergeLayer(dense=True))
+            ni += in_channels(encoder)
+            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
+        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
+    def __del__(self):
+        if hasattr(self, "sfs"): self.sfs.remove()
+
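+#Variant of DynamicUnet3 without batch norm: extra_bn is forced to False and the batchnorm_2d layer after the encoder is dropped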
+class DynamicUnet4(SequentialEx):
+    "Create a U-Net from a given architecture."
+    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
+                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
+                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
+        #extra_bn =  norm_type == NormType.Spectral
+        extra_bn = False
+        imsize = (256,256)
+        sfs_szs = model_sizes(encoder, size=imsize)
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        x = dummy_eval(encoder, imsize).detach()
+
+        ni = sfs_szs[-1][1]
+        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        x = middle_conv(x)
+        #layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
+        layers = [encoder, nn.ReLU(), middle_conv]
+
+        for i,idx in enumerate(sfs_idxs):
+            not_final = i!=len(sfs_idxs)-1
+            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            sa = self_attention and (i==len(sfs_idxs)-3)
+            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
+            layers.append(unet_block)
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            layers.append(MergeLayer(dense=True))
+            ni += in_channels(encoder)
+            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
+        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
     def __del__(self):
         if hasattr(self, "sfs"): self.sfs.remove()