Bladeren bron

More cleanup

Jason Antic 6 jaren geleden
bovenliggende
commit
90c49c994f
10 gewijzigde bestanden met toevoegingen van 233 en 461 verwijderingen
  1. 4 0
      .vscode/settings.json
  2. 1 12
      DeOldify_colab.ipynb
  3. 2 3
      fasterai/critics.py
  4. 0 1
      fasterai/dataset.py
  5. 1 1
      fasterai/filters.py
  6. 18 63
      fasterai/generators.py
  7. 2 38
      fasterai/layers.py
  8. 184 179
      fasterai/tensorboard.py
  9. 13 163
      fasterai/unet.py
  10. 8 1
      fasterai/visualize.py

+ 4 - 0
.vscode/settings.json

@@ -0,0 +1,4 @@
+{
+    "python.pythonPath": "/home/jason/anaconda3/envs/fastaiv1/bin/python",
+    "python.linting.enabled": true
+}

+ 1 - 12
DeOldify_colab.ipynb

@@ -237,17 +237,6 @@
     "!mkdir \"/content/drive/My Drive/deOldifyImages/results\""
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
-    "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
-    "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -263,7 +252,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn = colorize_gen_learner_exp(data=data)\n",
+    "learn = colorize_gen_learner(data=data, nf_factor=1.25)\n",
     "#switch to read models from proper place\n",
     "learn.path = Path('./')\n",
     "learn.load(weights_name)\n",

+ 2 - 3
fasterai/critics.py

@@ -8,8 +8,7 @@ _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
 def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
     return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
 
-#TODO:  Merge with fastai core.  Just removed dense block.
-def gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
+def custom_gan_critic(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
     "Critic to train a `GAN`."
     layers = [
         _conv(n_channels, nf, ks=4, stride=2),
@@ -27,4 +26,4 @@ def gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
     return nn.Sequential(*layers)
 
 def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
-    return Learner(data, gan_critic2(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
+    return Learner(data, custom_gan_critic(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)

+ 0 - 1
fasterai/dataset.py

@@ -12,7 +12,6 @@ def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_s
         .random_split_by_pct(0.1, seed=random_seed))
 
     data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
-        #TODO:  Revisit transforms used here....
         .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25), size=sz, tfm_y=True)
         .databunch(bs=bs, num_workers=num_workers, no_check=True)
         .normalize(imagenet_stats, do_y=True))

+ 1 - 1
fasterai/filters.py

@@ -8,10 +8,10 @@ from fastai.vision.data import *
 from fastai import *
 import math
 from scipy import misc
-#from torchvision.transforms.functional import *
 import cv2
 from PIL import Image as PilImage
 
+
 class IFilter(ABC):
     @abstractmethod
     def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:

+ 18 - 63
fasterai/generators.py

@@ -1,80 +1,35 @@
 from fastai.vision import *
 from fastai.vision.learner import cnn_config
-from fasterai.unet import DynamicUnet2, DynamicUnet3, DynamicUnet4, DynamicUnet5
+from .unet import CustomDynamicUnet
 from .loss import FeatureLoss
+from .dataset import *
 
-def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):
-    return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
-                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)
+#Weights are implicitly read from the ./models/ folder
+# (presumably resolved relative to `root_folder` by `learn.load` -- TODO confirm against fastai's Learner.load)
+def colorize_gen_inference(root_folder:Path, weights_name:str, nf_factor:float)->Learner:
+      "Build a colorizing generator `Learner` on a dummy databunch, load `weights_name` into it and switch the model to eval mode."
+      # No real dataset is passed in; the learner is built on whatever `get_dummy_databunch` returns.
+      data = get_dummy_databunch()
+      # NOTE(review): `F.l1_loss` replaces the default `FeatureLoss()` here -- presumably just a cheap
+      # placeholder because the learner is only used for inference; confirm.
+      learn = colorize_gen_learner(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor)
+      # Point the learner at `root_folder` before loading the named weights.
+      learn.path = root_folder
+      learn.load(weights_name)
+      learn.model.eval()
+      return learn
 
-#The code below is meant to be merged into fastaiv1 ideally
-
-def unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
-                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
-                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=False, **kwargs:Any)->None:
-    "Build Unet learner from `data` and `arch`."
-    meta = cnn_config(arch)
-    body = create_body(arch, pretrained)
-    model = to_device(DynamicUnet2(body, n_classes=data.c, blur=blur, blur_final=blur_final,
-          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
-          bottle=bottle), data.device)
-    learn = Learner(data, model, **kwargs)
-    learn.split(ifnone(split_on,meta['split']))
-    if pretrained: learn.freeze()
-    apply_init(model[2], nn.init.kaiming_normal_)
-    return learn
-
-
-def unet_learner3(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
-                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
-                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
-    "Build Unet learner from `data` and `arch`."
-    meta = cnn_config(arch)
-    body = create_body(arch, pretrained)
-    model = to_device(DynamicUnet3(body, n_classes=data.c, blur=blur, blur_final=blur_final,
-          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
-          bottle=bottle, nf_factor=nf_factor), data.device)
-    learn = Learner(data, model, **kwargs)
-    learn.split(ifnone(split_on,meta['split']))
-    if pretrained: learn.freeze()
-    apply_init(model[2], nn.init.kaiming_normal_)
-    return learn
+def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:float=1.0)->Learner:
+    "Create the colorizing generator `Learner`: a `custom_unet_learner` on `arch` with fixed wd, blur, spectral norm, self attention and y_range."
+    # NOTE: the default `FeatureLoss()` is instantiated once, when this `def` is evaluated, and is
+    # shared by every call that does not pass `gen_loss`.
+    # `nf_factor` is forwarded unchanged to `custom_unet_learner`.
+    return custom_unet_learner(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
+                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
 
-
-#No batch norm in ESRGAN paper
-def unet_learner4(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+#The code below is meant to be merged into fastaiv1 ideally
+def custom_unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                  norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
                  blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
+                 bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->Learner:
     "Build Unet learner from `data` and `arch`."
     meta = cnn_config(arch)
     body = create_body(arch, pretrained)
-    model = to_device(DynamicUnet4(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+    model = to_device(CustomDynamicUnet(body, n_classes=data.c, blur=blur, blur_final=blur_final,
           self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
           bottle=bottle, nf_factor=nf_factor), data.device)
     learn = Learner(data, model, **kwargs)
     learn.split(ifnone(split_on,meta['split']))
     if pretrained: learn.freeze()
     apply_init(model[2], nn.init.kaiming_normal_)
-    return learn
-
-
-#No batch norm in ESRGAN paper, custom nf width
-def unet_learner5(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
-                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
-                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
-                 bottle:bool=True, **kwargs:Any)->None:
-    "Build Unet learner from `data` and `arch`."
-    meta = cnn_config(arch)
-    body = create_body(arch, pretrained)
-    model = to_device(DynamicUnet5(body, n_classes=data.c, blur=blur, blur_final=blur_final,
-          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
-          bottle=bottle), data.device)
-    learn = Learner(data, model, **kwargs)
-    learn.split(ifnone(split_on,meta['split']))
-    if pretrained: learn.freeze()
-    apply_init(model[2], nn.init.kaiming_normal_)
-    return learn
-
+    return learn

+ 2 - 38
fasterai/layers.py

@@ -6,7 +6,7 @@ from torch.autograd import Variable
 
 #The code below is meant to be merged into fastaiv1 ideally
 
-def conv_layer2(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
+def custom_conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
                norm_type:Optional[NormType]=NormType.Batch,  use_activ:bool=True, leaky:float=None,
                transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False,
                extra_bn:bool=False):
@@ -21,41 +21,5 @@ def conv_layer2(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:b
     layers = [conv]
     if use_activ: layers.append(relu(True, leaky=leaky))
     if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
-    
-    #TODO:  Account for 1D
-    #if norm_type==NormType.Weight: layers.append(MeanOnlyBatchNorm(nf))
-
     if self_attention: layers.append(SelfAttention(nf))
-    return nn.Sequential(*layers)
-
-class MeanOnlyBatchNorm(nn.Module):
-    def __init__(self, num_features, momentum=0.1):
-        super(MeanOnlyBatchNorm, self).__init__()
-        self.num_features = num_features
-        self.momentum = momentum
-        self.weight = Parameter(torch.Tensor(num_features))
-        self.bias = Parameter(torch.Tensor(num_features))
-
-        self.register_buffer('running_mean', torch.zeros(num_features))
-        self.reset_parameters()
-        
-    def reset_parameters(self):
-        self.running_mean.zero_()
-        self.weight.data.uniform_()
-        self.bias.data.zero_()
-
-    def forward(self, inp):
-        size = list(inp.size())
-        gamma = self.weight.view(1, self.num_features, 1, 1)
-        beta = self.bias.view(1, self.num_features, 1, 1)
-
-        if self.training:
-            avg = torch.mean(inp.view(size[0], self.num_features, -1), dim=2)
-            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * torch.mean(avg.data, dim=0)
-        else:
-            avg = Variable(self.running_mean.repeat(size[0], 1), requires_grad=False)
-
-        output = inp - avg.view(size[0], size[1], 1, 1)
-        output = output*gamma + beta
-
-        return output
+    return nn.Sequential(*layers)

+ 184 - 179
fasterai/tensorboard.py

@@ -1,4 +1,4 @@
-import fastai
+"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
 from fastai.basic_train import Learner
 from fastai.basic_data import DatasetType, DataBunch
 from fastai.vision import Image
@@ -6,7 +6,6 @@ from fastai.callbacks import LearnerCallback
 from fastai.core import *
 from fastai.torch_core import *
 from threading import Thread, Event
-import time
 from time import sleep
 from queue import Queue
 import statistics
@@ -15,6 +14,189 @@ from abc import ABC, abstractmethod
 from tensorboardX import SummaryWriter
 
 
+__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
+
+
+class LearnerTensorboardWriter(LearnerCallback):
+    "Callback that writes a `Learner`'s training loss, metrics, weight histograms and model stats to Tensorboard."
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
+        # `loss_iters`, `hist_iters` and `stats_iters` are intervals in iterations: the matching
+        # write happens when `iteration % <interval> == 0` (see on_batch_end / on_backward_end).
+        super().__init__(learn=learn)
+        self.base_dir = base_dir
+        self.name = name
+        # Events are written under base_dir/name.
+        log_dir = base_dir/name
+        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
+        self.loss_iters = loss_iters
+        self.hist_iters = hist_iters
+        self.stats_iters = stats_iters
+        self.hist_writer = HistogramTBWriter()
+        self.stats_writer = ModelStatsTBWriter()
+        # Starts as None so that the call below sees a mismatch with `learn.data` and fetches the batches.
+        self.data = None
+        # Prefix for every scalar tag written by this callback.
+        self.metrics_root = '/metrics/'
+        self._update_batches_if_needed()
+
+    def _update_batches_if_needed(self):
+        "Cache one train and one valid batch, re-fetching only when `learn.data` is a different object than last time."
+        # one_batch function is extremely slow with large datasets.  This is an optimization.
+        # Note also that we always want to show the same batches so we can see changes
+        # in tensorboard
+        update_batches = self.data is not self.learn.data
+
+        if update_batches:
+            self.data = self.learn.data
+            self.trn_batch = self.learn.data.one_batch(
+                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
+            self.val_batch = self.learn.data.one_batch(
+                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
+
+    def _write_model_stats(self, iteration:int):
+        "Hand `learn.model` to the stats writer for this `iteration`."
+        self.stats_writer.write(
+            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
+
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
+        "Write `last_loss` as the scalar `<metrics_root>train_loss` at step `iteration`."
+        scalar_value = to_np(last_loss)
+        tag = self.metrics_root + 'train_loss'
+        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def _write_weight_histograms(self, iteration:int):
+        "Hand `learn.model` to the histogram writer for this `iteration`."
+        self.hist_writer.write(
+            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
+
+    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
+    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
+        "Write `last_metrics[i]` as a scalar tagged with `recorder.names[start_idx + i]`."
+        # presumably the first `start_idx` recorder names are non-metric columns -- TODO confirm
+        recorder = self.learn.recorder
+
+        for i, name in enumerate(recorder.names[start_idx:]):
+            # Stop as soon as there are more names than metric values.
+            if len(last_metrics) < i+1: return
+            scalar_value = last_metrics[i]
+            tag = self.metrics_root + name
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
+        "Write the training loss every `loss_iters` iterations and weight histograms every `hist_iters`; iteration 0 is skipped."
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        if iteration % self.loss_iters == 0:
+            self._write_training_loss(iteration=iteration, last_loss=last_loss)
+
+        if iteration % self.hist_iters == 0:
+            self._write_weight_histograms(iteration=iteration)
+
+    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
+    def on_backward_end(self, iteration:int, **kwargs):
+        "Write model stats every `stats_iters` iterations; iteration 0 is skipped."
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        if iteration % self.stats_iters == 0:
+            self._write_model_stats(iteration=iteration)
+
+    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
+        "Write the epoch's metrics, using the current `iteration` as the global step."
+        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
+
+# TODO:  We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a")
+class GANTensorboardWriter(LearnerTensorboardWriter):
+    "Tensorboard callback for a GAN learner: writes generator/critic histograms and stats, the GAN trainer's loss, and generated images."
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
+                 stats_iters:int=100, visual_iters:int=100):
+        # `visual_iters` is the interval (in iterations) between image writes; the other
+        # arguments are forwarded unchanged to LearnerTensorboardWriter.
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
+                         hist_iters=hist_iters, stats_iters=stats_iters)
+        self.visual_iters = visual_iters
+        self.img_gen_vis = ImageTBWriter()
+        # True means "no stats write pending" for that model; on_backward_end resets both
+        # to False every `stats_iters` iterations.
+        self.gen_stats_updated = True
+        self.crit_stats_updated = True
+
+    # override
+    def _write_weight_histograms(self, iteration:int):
+        "Write histograms for the trainer's generator and critic under the names 'generator' and 'critic'."
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+        self.hist_writer.write(
+            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
+        self.hist_writer.write(
+            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
+
+    # override
+    def _write_model_stats(self, iteration:int):
+        "Write stats for whichever model (generator or critic) matches the trainer's current mode, if its write is still pending."
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+
+        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
+        gen_mode = trainer.gen_mode
+
+        if gen_mode and not self.gen_stats_updated:
+            self.stats_writer.write(
+                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
+            self.gen_stats_updated = True
+
+        if not gen_mode and not self.crit_stats_updated:
+            self.stats_writer.write(
+                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
+            self.crit_stats_updated = True
+
+    # override
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
+        "Write the last entry of the GAN trainer's `recorder.losses` as `<metrics_root>train_loss`; `last_loss` is not used."
+        trainer = self.learn.gan_trainer
+        recorder = trainer.recorder
+
+        # Nothing is written until the recorder holds at least one loss.
+        if len(recorder.losses) > 0:
+            scalar_value = to_np((recorder.losses[-1:])[0])
+            tag = self.metrics_root + 'train_loss'
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def _write(self, iteration:int):
+        "Write images for the cached train/valid batches with the trainer temporarily switched to generator mode."
+        trainer = self.learn.gan_trainer
+        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  Is there a better way?
+        gen_mode = trainer.gen_mode
+
+        try:
+            trainer.switch(gen_mode=True)
+            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
+                                                    iteration=iteration, tbwriter=self.tbwriter)
+        # Restore the mode that was active on entry, even if the image write raised.
+        finally:                                      
+            trainer.switch(gen_mode=gen_mode)
+
+    # override
+    def on_batch_end(self, iteration:int, **kwargs):
+        "Run the base-class loss/histogram writes, then write images every `visual_iters` iterations."
+        # `last_loss`, which the base-class signature requires, reaches it through **kwargs.
+        super().on_batch_end(iteration=iteration, **kwargs)
+        if iteration == 0: return
+        if iteration % self.visual_iters == 0:
+            self._write(iteration=iteration)
+
+    # override
+    def on_backward_end(self, iteration:int, **kwargs):
+        "Every `stats_iters` iterations mark both models as pending, then keep attempting stats writes until both are done."
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill.
+        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
+        if iteration % self.stats_iters == 0:
+            self.gen_stats_updated = False
+            self.crit_stats_updated = False
+
+        # While either write is pending, try again on each backward pass; _write_model_stats
+        # only writes the model that matches the trainer's current mode.
+        if not (self.gen_stats_updated and self.crit_stats_updated):
+            self._write_model_stats(iteration=iteration)
+
+
+class ImageGenTensorboardWriter(LearnerTensorboardWriter):
+    "Tensorboard callback that adds periodic image writes to the base loss/histogram/stats/metrics logging."
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
+                 stats_iters: int = 100, visual_iters: int = 100):
+        # `visual_iters` is the interval (in iterations) between image writes; the other
+        # arguments are forwarded unchanged to LearnerTensorboardWriter.
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
+                         stats_iters=stats_iters)
+        self.visual_iters = visual_iters
+        self.img_gen_vis = ImageTBWriter()
+
+    def _write(self, iteration:int):
+        "Pass the learner and the cached train/valid batches to the image writer for this `iteration`."
+        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
+                                                  iteration=iteration, tbwriter=self.tbwriter)
+
+    # override
+    def on_batch_end(self, iteration:int, **kwargs):
+        "Run the base-class loss/histogram writes, then write images every `visual_iters` iterations."
+        # `last_loss`, which the base-class signature requires, reaches it through **kwargs.
+        super().on_batch_end(iteration=iteration, **kwargs)
+        if iteration == 0: return
+
+        if iteration % self.visual_iters == 0:
+            self._write(iteration=iteration)
+
+
+#------PRIVATE-----------
+
 class TBWriteRequest(ABC):
     def __init__(self, tbwriter: SummaryWriter, iteration:int):
         super().__init__()
@@ -218,180 +400,3 @@ class ImageTBWriter():
 
 
 
-#--------CALLBACKS----------------#
-class LearnerTensorboardWriter(LearnerCallback):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
-        super().__init__(learn=learn)
-        self.base_dir = base_dir
-        self.name = name
-        log_dir = base_dir/name
-        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.loss_iters = loss_iters
-        self.hist_iters = hist_iters
-        self.stats_iters = stats_iters
-        self.hist_writer = HistogramTBWriter()
-        self.stats_writer = ModelStatsTBWriter()
-        self.data = None
-        self.metrics_root = '/metrics/'
-        self._update_batches_if_needed()
-
-    def _update_batches_if_needed(self):
-        # one_batch function is extremely slow with large datasets.  This is an optimization.
-        # Note that also we want to always show the same batches so we can see changes 
-        # in tensorboard
-        update_batches = self.data is not self.learn.data
-
-        if update_batches:
-            self.data = self.learn.data
-            self.trn_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
-            self.val_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
-
-    def _write_model_stats(self, iteration:int):
-        self.stats_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        scalar_value = to_np(last_loss)
-        tag = self.metrics_root + 'train_loss'
-        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write_weight_histograms(self, iteration:int):
-        self.hist_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
-    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
-        recorder = self.learn.recorder
-
-        for i, name in enumerate(recorder.names[start_idx:]):
-            if len(last_metrics) < i+1: return
-            scalar_value = last_metrics[i]
-            tag = self.metrics_root + name
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.loss_iters == 0:
-            self._write_training_loss(iteration=iteration, last_loss=last_loss)
-
-        if iteration % self.hist_iters == 0:
-            self._write_weight_histograms(iteration=iteration)
-
-    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.stats_iters == 0:
-            self._write_model_stats(iteration=iteration)
-
-    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
-        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
-
-# TODO:  We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a")
-class GANTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters:int=100, visual_iters:int=100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
-                         hist_iters=hist_iters, stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-        self.gen_stats_updated = True
-        self.crit_stats_updated = True
-
-    # override
-    def _write_weight_histograms(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-        self.hist_writer.write(
-            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
-        self.hist_writer.write(
-            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
-
-    # override
-    def _write_model_stats(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-
-        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
-        gen_mode = trainer.gen_mode
-
-        if gen_mode and not self.gen_stats_updated:
-            self.stats_writer.write(
-                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
-            self.gen_stats_updated = True
-
-        if not gen_mode and not self.crit_stats_updated:
-            self.stats_writer.write(
-                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
-            self.crit_stats_updated = True
-
-    # override
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        trainer = self.learn.gan_trainer
-        recorder = trainer.recorder
-
-        if len(recorder.losses) > 0:
-            scalar_value = to_np((recorder.losses[-1:])[0])
-            tag = self.metrics_root + 'train_loss'
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  Is there a better way?
-        gen_mode = trainer.gen_mode
-
-        try:
-            trainer.switch(gen_mode=True)
-            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                                    iteration=iteration, tbwriter=self.tbwriter)
-        finally:                                      
-            trainer.switch(gen_mode=gen_mode)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)
-
-    # override
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill. 
-        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
-        if iteration % self.stats_iters == 0:
-            self.gen_stats_updated = False
-            self.crit_stats_updated = False
-
-        if not (self.gen_stats_updated and self.crit_stats_updated):
-            self._write_model_stats(iteration=iteration)
-
-
-class ImageGenTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters: int = 100, visual_iters: int = 100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
-                         stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-
-    def _write(self, iteration:int):
-        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                                  iteration=iteration, tbwriter=self.tbwriter)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)

+ 13 - 163
fasterai/unet.py

@@ -1,11 +1,11 @@
 from fastai.layers import *
-from fasterai.layers import *
+from .layers import *
 from fastai.torch_core import *
 from fastai.callbacks.hooks import *
 
 #The code below is meant to be merged into fastaiv1 ideally
 
-__all__ = ['DynamicUnet2', 'UnetBlock2']
+__all__ = ['CustomDynamicUnet', 'CustomUnetBlock', 'CustomPixelShuffle_ICNR']
 
 def _get_sfs_idxs(sizes:Sizes) -> List[int]:
     "Get the indexes of the layers where the size of the activation changes."
@@ -14,12 +14,12 @@ def _get_sfs_idxs(sizes:Sizes) -> List[int]:
     if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
     return sfs_idxs
 
-class PixelShuffle_ICNR2(nn.Module):
+class CustomPixelShuffle_ICNR(nn.Module):
     "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
     def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, leaky:float=None, **kwargs):
         super().__init__()
         nf = ifnone(nf, ni)
-        self.conv = conv_layer2(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
+        self.conv = custom_conv_layer(ni, nf*(scale**2), ks=1, use_activ=False, **kwargs)
         icnr(self.conv[0].weight)
         self.shuf = nn.PixelShuffle(scale)
         # Blurring over (h*w) kernel
@@ -33,18 +33,18 @@ class PixelShuffle_ICNR2(nn.Module):
         x = self.shuf(self.relu(self.conv(x)))
         return self.blur(self.pad(x)) if self.blur else x
 
-class UnetBlock2(nn.Module):
+class CustomUnetBlock(nn.Module):
     "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
     def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                  self_attention:bool=False, nf_factor:float=1.0,  **kwargs):
         super().__init__()
         self.hook = hook
-        self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
+        self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
         self.bn = batchnorm_2d(x_in_c)
         ni = up_in_c//2 + x_in_c
         nf = int((ni if final_div else ni//2)*nf_factor)
-        self.conv1 = conv_layer2(ni, nf, leaky=leaky, **kwargs)
-        self.conv2 = conv_layer2(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
+        self.conv2 = custom_conv_layer(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
         self.relu = relu(leaky=leaky)
 
     def forward(self, up_in:Tensor) -> Tensor:
@@ -57,12 +57,11 @@ class UnetBlock2(nn.Module):
         return self.conv2(self.conv1(cat_x))
 
 
-class DynamicUnet2(SequentialEx):
+class CustomDynamicUnet(SequentialEx):
     "Create a U-Net from a given architecture."
     def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
                  y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
                  norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
-        #extra_bn =  norm_type in (NormType.Spectral, NormType.Weight)
         extra_bn =  norm_type == NormType.Spectral
         imsize = (256,256)
         sfs_szs = model_sizes(encoder, size=imsize)
@@ -71,8 +70,8 @@ class DynamicUnet2(SequentialEx):
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
         x = middle_conv(x)
         layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
 
@@ -81,94 +80,7 @@ class DynamicUnet2(SequentialEx):
             up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
             do_blur = blur and (not_final or blur_final)
             sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
-            layers.append(unet_block)
-            x = unet_block(x)
-
-        ni = x.shape[1]
-        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
-        if last_cross:
-            layers.append(MergeLayer(dense=True))
-            ni += in_channels(encoder)
-            #TODO:  Missing norm_type argument here.  DOH!
-            layers.append(res_block(ni, bottle=bottle, **kwargs))
-        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
-        if y_range is not None: layers.append(SigmoidRange(*y_range))
-        super().__init__(*layers)
-
-    def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()
-
-
-class DynamicUnet3(SequentialEx):
-    "Create a U-Net from a given architecture."
-    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
-                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
-                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
-        extra_bn =  norm_type == NormType.Spectral
-        imsize = (256,256)
-        sfs_szs = model_sizes(encoder, size=imsize)
-        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
-        x = dummy_eval(encoder, imsize).detach()
-
-        ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
-        x = middle_conv(x)
-        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
-
-        for i,idx in enumerate(sfs_idxs):
-            not_final = i!=len(sfs_idxs)-1
-            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
-            do_blur = blur and (not_final or blur_final)
-            sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
-            layers.append(unet_block)
-            x = unet_block(x)
-
-        ni = x.shape[1]
-        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
-        if last_cross:
-            layers.append(MergeLayer(dense=True))
-            ni += in_channels(encoder)
-            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
-        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
-        if y_range is not None: layers.append(SigmoidRange(*y_range))
-        super().__init__(*layers)
-
-    def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()
-
-#No batch norm
-class DynamicUnet4(SequentialEx):
-    "Create a U-Net from a given architecture."
-    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
-                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
-                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:float=1.0, **kwargs):
-        #extra_bn =  norm_type == NormType.Spectral
-        extra_bn = False
-        imsize = (256,256)
-        sfs_szs = model_sizes(encoder, size=imsize)
-        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
-        x = dummy_eval(encoder, imsize).detach()
-
-        ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
-        x = middle_conv(x)
-        #layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
-        layers = [encoder, nn.ReLU(), middle_conv]
-
-        for i,idx in enumerate(sfs_idxs):
-            not_final = i!=len(sfs_idxs)-1
-            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
-            do_blur = blur and (not_final or blur_final)
-            sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = UnetBlock2(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+            unet_block = CustomUnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
                                    norm_type=norm_type, extra_bn=extra_bn, nf_factor=nf_factor, **kwargs).eval()
             layers.append(unet_block)
             x = unet_block(x)
@@ -179,74 +91,12 @@ class DynamicUnet4(SequentialEx):
             layers.append(MergeLayer(dense=True))
             ni += in_channels(encoder)
             layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
-        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
         if y_range is not None: layers.append(SigmoidRange(*y_range))
         super().__init__(*layers)
 
     def __del__(self):
         if hasattr(self, "sfs"): self.sfs.remove()
 
-class UnetBlock5(nn.Module):
-    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
-    def __init__(self, up_in_c:int, x_in_c:int, out_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
-                 self_attention:bool=False,  **kwargs):
-        super().__init__()
-        self.hook = hook
-        self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
-        self.bn = batchnorm_2d(x_in_c)
-        ni = up_in_c//2 + x_in_c
-        nf = out_c
-        self.conv = conv_layer2(ni, nf, leaky=leaky, self_attention=self_attention, **kwargs)
-        self.relu = relu(leaky=leaky)
 
-    def forward(self, up_in:Tensor) -> Tensor:
-        s = self.hook.stored
-        up_out = self.shuf(up_in)
-        ssh = s.shape[-2:]
-        if ssh != up_out.shape[-2:]:
-            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
-        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
-        return self.conv(cat_x)
 
-#custom filter widths
-class DynamicUnet5(SequentialEx):
-    "Create a U-Net from a given architecture."
-    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
-                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=True,
-                 norm_type:Optional[NormType]=NormType.Batch, nf:int=256, **kwargs):
-        extra_bn =  norm_type == NormType.Spectral
-        imsize = (256,256)
-        sfs_szs = model_sizes(encoder, size=imsize)
-        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
-        x = dummy_eval(encoder, imsize).detach()
-
-        ni = sfs_szs[-1][1]
-        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
-                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
-        x = middle_conv(x)
-        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
-
-        for i,idx in enumerate(sfs_idxs):
-            not_final = i!=len(sfs_idxs)-1
-            up_in_c = int(x.shape[1]) if i == 0 else nf
-            x_in_c = int(sfs_szs[idx][1])
-            do_blur = blur and (not_final or blur_final)
-            sa = self_attention and (i==len(sfs_idxs)-3)
-            unet_block = UnetBlock5(up_in_c, x_in_c, nf, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
-                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
-            layers.append(unet_block)
-            x = unet_block(x)
-
-        ni = x.shape[1]
-        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
-        if last_cross:
-            layers.append(MergeLayer(dense=True))
-            ni += in_channels(encoder)
-            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
-        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
-        if y_range is not None: layers.append(SigmoidRange(*y_range))
-        super().__init__(*layers)
-
-    def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()

+ 8 - 1
fasterai/visualize.py

@@ -3,7 +3,8 @@ from fastai.vision import *
 from matplotlib.axes import Axes
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
-from .filters import IFilter
+from .filters import IFilter, MasterFilter, ColorizerFilter
+from .generators import colorize_gen_inference
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
@@ -51,6 +52,12 @@ class ModelImageVisualizer():
         return rows, columns
 
 
+def get_colorize_visualizer(root_folder:Path=Path('./'), weights_name:str='colorizer_gen', 
+        results_dir = 'result_images', nf_factor:float=1.25, render_factor:int=21)->ModelImageVisualizer:  # Factory: returns a ModelImageVisualizer wired to a colorizer learner.
+    learn = colorize_gen_inference(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor)  # inference learner; presumably loads `weights_name` from under `root_folder` -- TODO confirm in generators.py
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)  # one-element filter list wrapping the learner; render_factor is forwarded unchanged
+    vis = ModelImageVisualizer(filtr, results_dir=results_dir)  # visualizer is handed the filter and the results_dir argument as given
+    return vis