Browse Source

Progress on fastai v1 upgrade

Jason Antic 6 years ago
parent
commit
8168931689
11 changed files with 1006 additions and 116 deletions
  1. .gitignore  +3 -0
  2. ColorizeTraining.ipynb  +50 -79
  3. ColorizeTrainingNew.ipynb  +658 -0
  4. ColorizeVisualization.ipynb  +21 -20
  5. fasterai/critics.py  +24 -3
  6. fasterai/dataset.py  +1 -1
  7. fasterai/generators.py  +21 -2
  8. fasterai/layers.py  +61 -0
  9. fasterai/loss.py  +49 -4
  10. fasterai/tensorboard.py  +18 -7
  11. fasterai/unet.py  +100 -0

+ 3 - 0
.gitignore

@@ -485,3 +485,6 @@ test_images/Andy.jpg
 *.prof
 fastai
 *.pth
+.ipynb_checkpoints/SuperResTraining-checkpoint.ipynb
+.ipynb_checkpoints/ColorizeTrainingNew2-checkpoint.ipynb
+.ipynb_checkpoints/ColorizeTrainingNew-checkpoint.ipynb

+ 50 - 79
ColorizeTraining.ipynb

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -37,25 +37,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
     "BWIMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
     "\n",
-    "proj_id = 'colorizeV2h'\n",
+    "proj_id = 'colorizeV5o'\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
-    "gpath = IMAGENET.parent/(proj_id + '_gen_128.h5')\n",
-    "dpath = IMAGENET.parent/(proj_id + '_critic_128.h5')\n",
+    "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
+    "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
     "\n",
     "torch.backends.cudnn.benchmark=True"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -75,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -85,7 +85,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -95,18 +95,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def get_data(sz:int, bs:int, keep_pct:float):\n",
     "    return get_colorize_data(sz=sz, bs=bs, crappy_path=BWIMAGENET, good_path=IMAGENET, \n",
-    "                             random_seed=None, keep_pct=keep_pct,num_workers=32)"
+    "                             random_seed=None, keep_pct=keep_pct,num_workers=16)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -126,6 +126,17 @@
     "    learn_crit.load(proj_id + '_crit_' + str(sz))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
+    "    return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
+    "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -141,17 +152,17 @@
    "source": [
     "#Needed to instantiate critic but not actually used\n",
     "sz=64\n",
-    "bs=128\n",
+    "bs=32\n",
     "\n",
     "data = get_data(sz=sz, bs=bs, keep_pct=1.0)\n",
-    "learn_crit = colorize_crit_learner(data=data, nf=128)\n",
+    "learn_crit = colorize_crit_learner(data=data, nf=256)\n",
     "learn_crit.unfreeze()\n",
     "\n",
-    "gen_loss = FeatureLoss()\n",
-    "learn_gen = colorize_gen_learner(data=data, gen_loss=gen_loss, arch=models.resnet34)\n",
+    "gen_loss = FeatureLoss2(gram_wgt=5e3)\n",
+    "learn_gen = colorize_gen_learner_exp(data=data)\n",
     "\n",
     "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
-    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,0.15), show_img=False, switcher=switcher,\n",
+    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.0), show_img=False, switcher=switcher,\n",
     "                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=1e-3)\n",
     "\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
@@ -172,64 +183,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "    <div>\n",
-       "        <style>\n",
-       "            /* Turns off some styling */\n",
-       "            progress {\n",
-       "                /* gets rid of default border in Firefox and Opera. */\n",
-       "                border: none;\n",
-       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
-       "                background-size: auto;\n",
-       "            }\n",
-       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
-       "                background: #F44336;\n",
-       "            }\n",
-       "        </style>\n",
-       "      <progress value='0' class='' max='1', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-       "      0.00% [0/1 00:00<00:00]\n",
-       "    </div>\n",
-       "    \n",
-       "<table style='width:300px; margin-bottom:10px'>\n",
-       "  <tr>\n",
-       "    <th>epoch</th>\n",
-       "    <th>train_loss</th>\n",
-       "    <th>gen_loss</th>\n",
-       "    <th>disc_loss</th>\n",
-       "  </tr>\n",
-       "</table>\n",
-       "\n",
-       "\n",
-       "    <div>\n",
-       "        <style>\n",
-       "            /* Turns off some styling */\n",
-       "            progress {\n",
-       "                /* gets rid of default border in Firefox and Opera. */\n",
-       "                border: none;\n",
-       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
-       "                background-size: auto;\n",
-       "            }\n",
-       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
-       "                background: #F44336;\n",
-       "            }\n",
-       "        </style>\n",
-       "      <progress value='4938' class='' max='9684', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
-       "      50.99% [4938/9684 1:32:47<1:29:10 1.0264]\n",
-       "    </div>\n",
-       "    "
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "learn_gen.freeze_to(-1)\n",
     "learn.fit(1,lr)"
@@ -276,9 +230,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "load()\n",
     "lr=lr/2\n",
     "sz=96\n",
-    "bs=bs//2"
+    "#bs=bs//2"
    ]
   },
   {
@@ -448,7 +403,16 @@
    "source": [
     "lr=lr/1.5\n",
     "sz=160\n",
-    "bs=bs//1.5"
+    "bs=int(bs//1.5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=10"
    ]
   },
   {
@@ -533,7 +497,7 @@
    "source": [
     "lr=lr/1.5\n",
     "sz=192\n",
-    "bs=bs//1.5"
+    "bs=int(bs//1.5)"
    ]
   },
   {
@@ -618,7 +582,7 @@
    "source": [
     "lr=lr/1.5\n",
     "sz=224\n",
-    "bs=bs//1.5"
+    "bs=int(bs//1.5)"
    ]
   },
   {
@@ -687,6 +651,13 @@
    "source": [
     "save()"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {

File diff suppressed because it is too large
+ 658 - 0
ColorizeTrainingNew.ipynb


+ 21 - 20
ColorizeVisualization.ipynb

@@ -30,7 +30,7 @@
     "from fasterai.generators import *\n",
     "from pathlib import Path\n",
     "from itertools import repeat\n",
-    "torch.cuda.set_device(0)\n",
+    "torch.cuda.set_device(2)\n",
     "plt.style.use('dark_background')\n",
     "torch.backends.cudnn.benchmark=True"
    ]
@@ -41,13 +41,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "DUMMY_DATA = Path('./')\n",
-    "colorizer_file = 'colorizeV2c_gen_64'\n",
+    "#Dummy data path- shouldn't pull any images.\n",
+    "path = Path('./')\n",
     "#The higher the render_factor, the more GPU memory will be used and generally images will look better.  \n",
     "#11GB can take a factor of 42 max.  Performance generally gracefully degrades with lower factors, \n",
     "#though you may also find that certain images will actually render better at lower numbers.  \n",
     "#This tends to be the case with the oldest photos.\n",
-    "render_factor=42"
+    "render_factor=62"
    ]
   },
   {
@@ -56,9 +56,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "data = get_colorize_data(sz=256, bs=4, crappy_path=DUMMY_DATA, good_path=DUMMY_DATA, keep_pct=0.01)\n",
+    "data = get_colorize_data(sz=256, bs=4, crappy_path=path, good_path=path, keep_pct=0.01)\n",
     "learn = colorize_gen_learner(data=data)\n",
-    "learn.load(colorizer_file)\n",
+    "learn.path = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
+    "learn.load('gen-pre-a')\n",
     "filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n",
     "vis = ModelImageVisualizer(filtr, results_dir='result_images')"
    ]
@@ -78,7 +79,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Chief.jpg\", render_factor=16)"
+    "vis.plot_transformed_image(\"test_images/Chief.jpg\")"
    ]
   },
   {
@@ -87,7 +88,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1850SchoolForGirls.jpg\", render_factor=42)"
+    "vis.plot_transformed_image(\"test_images/1850SchoolForGirls.jpg\")"
    ]
   },
   {
@@ -96,7 +97,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/AtlanticCityBeach1905.jpg\", render_factor=42)"
+    "vis.plot_transformed_image(\"test_images/AtlanticCityBeach1905.jpg\")"
    ]
   },
   {
@@ -105,7 +106,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/CottonMillWorkers1913.jpg\", render_factor=41)"
+    "vis.plot_transformed_image(\"test_images/CottonMillWorkers1913.jpg\")"
    ]
   },
   {
@@ -123,7 +124,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FinnishPeasant1867.jpg\", render_factor=20)"
+    "vis.plot_transformed_image(\"test_images/FinnishPeasant1867.jpg\")"
    ]
   },
   {
@@ -132,7 +133,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/AtlanticCity1905.png\", render_factor=39)"
+    "vis.plot_transformed_image(\"test_images/AtlanticCity1905.png\")"
    ]
   },
   {
@@ -159,7 +160,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/IronLung.png\", render_factor=39)"
+    "vis.plot_transformed_image(\"test_images/IronLung.png\")"
    ]
   },
   {
@@ -177,7 +178,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/DayAtSeaBelgium.jpg\", render_factor=41)"
+    "vis.plot_transformed_image(\"test_images/DayAtSeaBelgium.jpg\")"
    ]
   },
   {
@@ -186,7 +187,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\", render_factor=30)"
+    "vis.plot_transformed_image(\"test_images/marilyn_woods.jpg\")"
    ]
   },
   {
@@ -195,7 +196,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/OldWomanSweden1904.jpg\", render_factor=39)"
+    "vis.plot_transformed_image(\"test_images/OldWomanSweden1904.jpg\")"
    ]
   },
   {
@@ -240,7 +241,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/abe.jpg\", render_factor=25)"
+    "vis.plot_transformed_image(\"test_images/abe.jpg\")"
    ]
   },
   {
@@ -249,7 +250,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/RossCorbettHouseCork.jpg\", render_factor=33)"
+    "vis.plot_transformed_image(\"test_images/RossCorbettHouseCork.jpg\")"
    ]
   },
   {
@@ -267,7 +268,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\", render_factor=42)"
+    "vis.plot_transformed_image(\"test_images/einstein_beach.jpg\")"
    ]
   },
   {
@@ -303,7 +304,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Rutherford_Hayes.jpg\", render_factor=36)"
+    "vis.plot_transformed_image(\"test_images/Rutherford_Hayes.jpg\")"
    ]
   },
   {
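
Note: the visualization calls above now rely on the render_factor baked into the filter (62 here). A per-image override is sketched below, assuming plot_transformed_image still accepts the render_factor keyword it took before this commit:

vis.plot_transformed_image("test_images/Chief.jpg")                    # filter default (62)
vis.plot_transformed_image("test_images/Chief.jpg", render_factor=21)  # lower factor: less GPU memory, sometimes better for very old photos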

+ 24 - 3
fasterai/critics.py

@@ -1,7 +1,28 @@
 from fastai.core import *
 from fastai.torch_core import *
 from fastai.vision import *
-from fastai.vision.gan import *
+from fastai.vision.gan import AdaptiveLoss
 
-def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=128)->Learner:
-    return Learner(data, gan_critic(nf=nf), metrics=None, loss_func=loss_critic, wd=1e-3)
+_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
+
+def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
+    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
+
+#TODO:  Merge with fastai core.  Just removed dense block.
+def gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
+    "Critic to train a `GAN`."
+    layers = [
+        _conv(n_channels, nf, ks=4, stride=2),
+        nn.Dropout2d(p/2)]
+    for i in range(n_blocks):
+        layers += [
+            nn.Dropout2d(p),
+            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
+        nf *= 2
+    layers += [
+        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
+        Flatten()]
+    return nn.Sequential(*layers)
+
+def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
+    return Learner(data, gan_critic2(nf=nf), metrics=None, loss_func=loss_critic, wd=1e-3)
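
A quick shape check of the new critic, as a sketch (assumes fastai v1 and PyTorch; no pretrained weights are needed):

import torch
from fasterai.critics import gan_critic2

critic = gan_critic2(n_channels=3, nf=256, n_blocks=3)
with torch.no_grad():
    out = critic(torch.randn(2, 3, 64, 64))  # four stride-2 convs: 64 -> 32 -> 16 -> 8 -> 4
print(out.shape)  # expected: torch.Size([2, 1]) after the final 4x4 valid conv and Flatten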

+ 1 - 1
fasterai/dataset.py

@@ -13,7 +13,7 @@ def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_s
 
     data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
         #TODO:  Revisit transforms used here....
-        .transform(get_transforms(), size=sz, tfm_y=True)
+        .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25), size=sz, tfm_y=True)
         .databunch(bs=bs, num_workers=num_workers)
         .normalize(imagenet_stats, do_y=True))
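
For reference, the strengthened augmentation above uses fastai v1's stock get_transforms; the remaining arguments keep their defaults (do_flip=True, max_rotate=10, p_affine=0.75, p_lighting=0.75 in fastai v1). A standalone sketch:

from fastai.vision import get_transforms

tfms = get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25)
train_tfms, valid_tfms = tfms  # a (train, valid) pair of transform lists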
 

+ 21 - 2
fasterai/generators.py

@@ -1,7 +1,26 @@
 from fastai.vision import *
-from fastai.vision.models.unet import *
+from fastai.vision.learner import cnn_config
+from fasterai.unet import *
 from .loss import FeatureLoss
 
 def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):
-    return unet_learner(data, arch, wd=1e-3, blur=True, norm_type=NormType.Weight,
+    return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                         self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)
+
+#The code below is meant to be merged into fastaiv1 ideally
+
+def unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
+                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
+                 bottle:bool=False, **kwargs:Any)->None:
+    "Build Unet learner from `data` and `arch`."
+    meta = cnn_config(arch)
+    body = create_body(arch, pretrained)
+    model = to_device(DynamicUnet2(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
+          bottle=bottle), data.device)
+    learn = Learner(data, model, **kwargs)
+    learn.split(ifnone(split_on,meta['split']))
+    if pretrained: learn.freeze()
+    apply_init(model[2], nn.init.kaiming_normal_)
+    return learn
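
A usage sketch for the experimental learner, mirroring colorize_gen_learner_exp from the training notebook; the data paths are placeholders, and FeatureLoss2 is imported separately since generators.py itself only pulls in FeatureLoss:

from fastai.vision import *
from fasterai.generators import unet_learner2
from fasterai.dataset import get_colorize_data
from fasterai.loss import FeatureLoss2

# Placeholder paths; substitute real black-and-white / color image folders.
data = get_colorize_data(sz=64, bs=4, crappy_path=Path('data/bandw'),
                         good_path=Path('data/color'), keep_pct=0.01)
learn_gen = unet_learner2(data, models.resnet34, wd=1e-3, blur=True,
                          norm_type=NormType.Spectral, self_attention=True,
                          y_range=(-3., 3.), loss_func=FeatureLoss2(gram_wgt=5e3))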

+ 61 - 0
fasterai/layers.py

@@ -0,0 +1,61 @@
+from fastai.layers import *
+from fastai.torch_core import *
+from torch.nn.parameter import Parameter
+from torch.autograd import Variable
+
+
+#The code below is meant to be merged into fastaiv1 ideally
+
+def conv_layer2(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
+               norm_type:Optional[NormType]=NormType.Batch,  use_activ:bool=True, leaky:float=None,
+               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False,
+               extra_bn:bool=False):
+    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
+    if padding is None: padding = (ks-1)//2 if not transpose else 0
+    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn==True
+    if bias is None: bias = not bn
+    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
+    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
+    if   norm_type==NormType.Weight:   conv = weight_norm(conv)
+    elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
+    layers = [conv]
+    if use_activ: layers.append(relu(True, leaky=leaky))
+    if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
+    
+    #TODO:  Account for 1D
+    #if norm_type==NormType.Weight: layers.append(MeanOnlyBatchNorm(nf))
+
+    if self_attention: layers.append(SelfAttention(nf))
+    return nn.Sequential(*layers)
+
+class MeanOnlyBatchNorm(nn.Module):
+    def __init__(self, num_features, momentum=0.1):
+        super(MeanOnlyBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.momentum = momentum
+        self.weight = Parameter(torch.Tensor(num_features))
+        self.bias = Parameter(torch.Tensor(num_features))
+
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.reset_parameters()
+        
+    def reset_parameters(self):
+        self.running_mean.zero_()
+        self.weight.data.uniform_()
+        self.bias.data.zero_()
+
+    def forward(self, inp):
+        size = list(inp.size())
+        gamma = self.weight.view(1, self.num_features, 1, 1)
+        beta = self.bias.view(1, self.num_features, 1, 1)
+
+        if self.training:
+            avg = torch.mean(inp.view(size[0], self.num_features, -1), dim=2)
+            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * torch.mean(avg.data, dim=0)
+        else:
+            avg = Variable(self.running_mean.repeat(size[0], 1), requires_grad=False)
+
+        output = inp - avg.view(size[0], size[1], 1, 1)
+        output = output*gamma + beta
+
+        return output
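
MeanOnlyBatchNorm is currently unused (see the commented-out line in conv_layer2); it centres each channel on its batch mean without normalizing variance, in the spirit of the mean-only batch norm described alongside weight normalization. A small behavioural sketch:

import torch
from fasterai.layers import MeanOnlyBatchNorm

mbn = MeanOnlyBatchNorm(8)
x = torch.randn(4, 8, 16, 16) + 3.0      # push activations well away from zero mean
y = mbn(x)                               # training mode: per-sample, per-channel mean is subtracted
print(x.mean().item(), y.mean().item())  # roughly 3.0 vs roughly 0.0 (bias starts at zero)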

+ 49 - 4
fasterai/loss.py

@@ -6,8 +6,9 @@ import torchvision.models as models
 
 
 class FeatureLoss(nn.Module):
-    def __init__(self, layer_wgts=[5,15,2]):
+    def __init__(self, layer_wgts:[float]=[5.0,15.0,2.0], gram_wgt:float=5e3):
         super().__init__()
+        self.gram_wgt = gram_wgt
         self.base_loss = F.l1_loss
         self.m_feat = models.vgg16_bn(True).features.cuda().eval()
         requires_grad(self.m_feat, False)
@@ -31,11 +32,14 @@ class FeatureLoss(nn.Module):
     def forward(self, input:torch.Tensor, target:torch.Tensor):
         out_feat = self.make_features(target, clone=True)
         in_feat = self.make_features(input)
-        self.feat_losses = [self.base_loss(input,target)]
-        self.feat_losses += [self.base_loss(f_in, f_out)*w
+        self.feat_losses = [self.base_loss(f_in, f_out)*w
                              for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * 5e3
+
+        self.feat_losses += [self.base_loss(input,target)]
+
+        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * self.gram_wgt
                              for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+                             
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
     
@@ -43,3 +47,44 @@ class FeatureLoss(nn.Module):
         self.hooks.remove()
 
 
+
+class FeatureLoss2(nn.Module):
+    def __init__(self, layer_wgts:[float]=[20.0,70.0,10.0], gram_wgt:float=5e3):
+        super().__init__()
+        self.gram_wgt = gram_wgt
+        self.base_loss = F.l1_loss
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
+              ] + [f'gram_{i}' for i in range(len(layer_ids))]
+
+    def _gram_matrix(self, x:torch.Tensor):
+        n,c,h,w = x.size()
+        x = x.view(n, c, -1)
+        return (x @ x.transpose(1,2))/(c*h*w)
+
+    def make_features(self, x:torch.Tensor, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+    
+    def forward(self, input:torch.Tensor, target:torch.Tensor):
+        out_feat = self.make_features(target, clone=True)
+        in_feat = self.make_features(input)
+        self.feat_losses = [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+
+        self.feat_losses += [self.base_loss(input,target)*100]
+
+        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * self.gram_wgt
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+                             
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+    
+    def __del__(self): 
+        self.hooks.remove()
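
A sketch of exercising FeatureLoss2 directly; it assumes a CUDA device (the VGG feature extractor is moved to .cuda() in __init__) and that the pretrained vgg16_bn weights can be downloaded:

import torch
from fasterai.loss import FeatureLoss2

loss_func = FeatureLoss2(layer_wgts=[20.0, 70.0, 10.0], gram_wgt=5e3)
pred   = torch.rand(2, 3, 64, 64, device='cuda', requires_grad=True)
target = torch.rand(2, 3, 64, 64, device='cuda')
loss = loss_func(pred, target)   # hooks are registered with detach=False, so gradients flow
loss.backward()
print(loss.item())
print({k: float(v) for k, v in loss_func.metrics.items()})  # per-term breakdown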

+ 18 - 7
fasterai/tensorboard.py

@@ -135,7 +135,10 @@ class GANTensorboardWriter(LearnerCallback):
         self.weight_vis = ModelHistogramVisualizer()
         self.data = None
 
-    def on_batch_end(self, iteration, last_loss, **kwargs):
+    def on_batch_end(self, iteration, metrics, **kwargs):
+        if iteration==0:
+            return
+
         trainer = self.learn.gan_trainer
         generator = trainer.generator
         critic = trainer.critic
@@ -154,7 +157,7 @@ class GANTensorboardWriter(LearnerCallback):
                 self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
 
             if len(recorder.val_losses) > 0:
-                val_loss = to_np((recorder.val_losses[-1:])[0])
+                val_loss = (recorder.val_losses[-1:])[0]
                 self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration) 
 
             #TODO:  Figure out how to do metrics here and gan vs critic loss
@@ -170,6 +173,7 @@ class GANTensorboardWriter(LearnerCallback):
         if iteration % self.weight_iters == 0:
             self.weight_vis.write_tensorboard_histograms(model=generator, iter_count=iteration, tbwriter=self.tbwriter)
             self.weight_vis.write_tensorboard_histograms(model=critic, iter_count=iteration, tbwriter=self.tbwriter)
+              
 
 
 class ImageGenTensorboardWriter(LearnerCallback):
@@ -188,7 +192,10 @@ class ImageGenTensorboardWriter(LearnerCallback):
         self.img_gen_vis = ImageGenVisualizer()
         self.data = None
 
-    def on_batch_end(self, iteration, last_loss, **kwargs):
+    def on_batch_end(self, iteration, last_loss, metrics, **kwargs):
+        if iteration==0:
+            return
+
         #one_batch is extremely slow.  this is an optimization
         update_batches = self.data is not self.learn.data
 
@@ -197,10 +204,9 @@ class ImageGenTensorboardWriter(LearnerCallback):
             self.trn_batch = self.learn.data.one_batch(DatasetType.Train, detach=False, denorm=False)
             self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=False, denorm=False)
 
-
-        if iteration % self.stats_iters == 0:
-            self.tbwriter.add_scalar('/loss/trn_loss', str(self.losses[-1:]), iteration)
-            self.tbwriter.add_scalar('/loss/val_loss', str(self.val_losses[-1:]), iteration) 
+        if iteration % self.stats_iters == 0: 
+            trn_loss = to_np(last_loss)
+            self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
 
         if iteration % self.visual_iters == 0:
             self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
@@ -208,3 +214,8 @@ class ImageGenTensorboardWriter(LearnerCallback):
 
         if iteration % self.weight_iters == 0:
             self.weight_vis.write_tensorboard_histograms(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter)
+
+    def on_epoch_end(self, iteration, metrics, last_metrics, **kwargs):  
+        #TODO: Not a fan of this indexing but...what to do?
+        val_loss = last_metrics[0]
+        self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration)   

+ 100 - 0
fasterai/unet.py

@@ -0,0 +1,100 @@
+from fastai.layers import *
+from fasterai.layers import *
+from fastai.torch_core import *
+from fastai.callbacks.hooks import *
+
+#The code below is meant to be merged into fastaiv1 ideally
+
+__all__ = ['DynamicUnet2', 'UnetBlock2']
+
+def _get_sfs_idxs(sizes:Sizes) -> List[int]:
+    "Get the indexes of the layers where the size of the activation changes."
+    feature_szs = [size[-1] for size in sizes]
+    sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
+    if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
+    return sfs_idxs
+
+class PixelShuffle_ICNR2(nn.Module):
+    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
+    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
+        super().__init__()
+        nf = ifnone(nf, ni)
+        self.conv = conv_layer2(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
+        icnr(self.conv[0].weight)
+        self.shuf = nn.PixelShuffle(scale)
+        # Blurring over (h*w) kernel
+        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
+        # - https://arxiv.org/abs/1806.02658
+        self.pad = nn.ReplicationPad2d((1,0,1,0))
+        self.blur = nn.AvgPool2d(2, stride=1)
+        self.relu = relu(True, leaky=leaky)
+
+    def forward(self,x):
+        x = self.shuf(self.relu(self.conv(x)))
+        return self.blur(self.pad(x)) if self.blur else x
+
+class UnetBlock2(nn.Module):
+    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
+    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
+                 self_attention:bool=False,  **kwargs):
+        super().__init__()
+        self.hook = hook
+        self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
+        self.bn = batchnorm_2d(x_in_c)
+        ni = up_in_c//2 + x_in_c
+        nf = ni if final_div else ni//2
+        self.conv1 = conv_layer2(ni, nf, leaky=leaky, **kwargs)
+        self.conv2 = conv_layer2(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.relu = relu(leaky=leaky)
+
+    def forward(self, up_in:Tensor) -> Tensor:
+        s = self.hook.stored
+        up_out = self.shuf(up_in)
+        ssh = s.shape[-2:]
+        if ssh != up_out.shape[-2:]:
+            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
+        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
+        return self.conv2(self.conv1(cat_x))
+
+
+class DynamicUnet2(SequentialEx):
+    "Create a U-Net from a given architecture."
+    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
+                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
+                 norm_type:Optional[NormType]=NormType.Batch, **kwargs):
+        #extra_bn =  norm_type in (NormType.Spectral, NormType.Weight)
+        extra_bn =  norm_type == NormType.Spectral
+        imsize = (256,256)
+        sfs_szs = model_sizes(encoder, size=imsize)
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        x = dummy_eval(encoder, imsize).detach()
+
+        ni = sfs_szs[-1][1]
+        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        x = middle_conv(x)
+        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
+
+        for i,idx in enumerate(sfs_idxs):
+            not_final = i!=len(sfs_idxs)-1
+            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            sa = self_attention and (i==len(sfs_idxs)-3)
+            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
+            layers.append(unet_block)
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            layers.append(MergeLayer(dense=True))
+            ni += in_channels(encoder)
+            layers.append(res_block(ni, bottle=bottle, **kwargs))
+        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
+    def __del__(self):
+        if hasattr(self, "sfs"): self.sfs.remove()
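
A construction sketch for the new U-Net, assuming fastai v1; the 64x64 input is arbitrary (hooks are registered via a 256x256 dummy pass at construction time, but forward adapts to the actual input size):

import torch
from fastai.vision import *
from fasterai.unet import DynamicUnet2

body = create_body(models.resnet34, pretrained=False)
unet = DynamicUnet2(body, n_classes=3, blur=True, self_attention=True,
                    y_range=(-3., 3.), norm_type=NormType.Spectral)
unet.eval()
with torch.no_grad():
    out = unet(torch.randn(1, 3, 64, 64))
print(out.shape)  # expected: torch.Size([1, 3, 64, 64])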

Some files were not shown because too many files have changed in this diff