瀏覽代碼

Much needed cleanup of code

Getting rid of unused modules; putting in missing type hints; deleting unused logic.
Jason Antic 6 年之前
父節點
當前提交
2a300aa84f

+ 1 - 0
.gitignore

@@ -28,3 +28,4 @@ SymbolicLinks.sh
 .ipynb_checkpoints/README-checkpoint.md
 .ipynb_checkpoints/ComboVisualization-checkpoint.ipynb
 .ipynb_checkpoints/ColorizeTraining2-checkpoint.ipynb
+test_images/Uaqapqr.jpg

+ 7 - 7
ColorizeTraining.ipynb

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -41,13 +41,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
     "proj_id = 'colorize'\n",
-    "TENSORBOARD_PATH = Path('data/tensorboard/')\n",
+    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
     "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
     "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
@@ -75,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -90,7 +90,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -100,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [

File diff suppressed because it is too large
+ 8 - 12
ColorizeVisualization.ipynb


+ 3 - 4
ComboVisualization.ipynb

@@ -38,7 +38,7 @@
     "plt.style.use('dark_background')\n",
     "torch.backends.cudnn.benchmark=True\n",
     "colorizer_device = torch.device('cuda:0')\n",
-    "defader_device = torch.device('cuda:3') \n"
+    "defader_device = torch.device('cuda:1') \n"
    ]
   },
   {
@@ -51,7 +51,7 @@
     "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
     "\n",
     "colorizer_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
-    "defader_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
+    "defader_path = IMAGENET.parent/('defade_rc_gen_256.h5')\n",
     "\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
@@ -99,8 +99,7 @@
    "outputs": [],
    "source": [
     "x_tfms = [BlackAndWhiteTransform()]\n",
-    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, x_noise=False,\n",
-    "            keep_pct=1.0, x_tfms=x_tfms)\n",
+    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, keep_pct=1.0, x_tfms=x_tfms)\n",
     "md = data_loader.get_model_data()"
    ]
   },

+ 2 - 4
DeFadeVisualization.ipynb

@@ -48,7 +48,7 @@
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
     "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
-    "gpath = IMAGENET.parent/('defade_gen_192.h5')\n",
+    "gpath = IMAGENET.parent/('defade_rc_gen_256.h5')\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
    ]
@@ -71,9 +71,7 @@
    "outputs": [],
    "source": [
     "x_tfms = []\n",
-    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, x_noise=False,\n",
-    "            keep_pct=1.0, x_tfms=x_tfms)\n",
-    "\n",
+    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, keep_pct=1.0, x_tfms=x_tfms)\n",
     "md = data_loader.get_model_data()"
    ]
   },

+ 8 - 8
fasterai/callbacks.py

@@ -3,11 +3,10 @@ from fastai.sgdr import Callback
 from fastai.dataset import ModelData, ImageData
 from fasterai.visualize import ModelStatsVisualizer, ImageGenVisualizer, GANTrainerStatsVisualizer
 from fasterai.visualize import LearnerStatsVisualizer, ModelGraphVisualizer, ModelHistogramVisualizer
-from matplotlib.axes import Axes
 from fasterai.training import GenResult, CriticResult, GANTrainer
 from tensorboardX import SummaryWriter
 
-def clear_directory(dir: Path):
+def clear_directory(dir:Path):
     for f in dir.glob('*'):
         os.remove(f)
 
@@ -23,7 +22,7 @@ class ModelVisualizationHook():
         self.iter_count = 0
         self.model_vis = ModelStatsVisualizer() 
 
-    def forward_hook(self, module: nn.Module, input, output): 
+    def forward_hook(self, module:nn.Module, input, output): 
         self.iter_count += 1
         if self.iter_count % self.stats_iters == 0:
             self.model_vis.write_tensorboard_stats(module, iter_count=self.iter_count, tbwriter=self.tbwriter)  
@@ -34,8 +33,8 @@ class ModelVisualizationHook():
         self.hook.remove()
 
 class GANVisualizationHook():
-    def __init__(self, base_dir: Path, trainer: GANTrainer, name: str, stats_iters: int=10, 
-            visual_iters: int=200, weight_iters: int=1000, jupyter:bool=False):
+    def __init__(self, base_dir:Path, trainer:GANTrainer, name:str, stats_iters:int=10, 
+            visual_iters:int=200, weight_iters:int=1000, jupyter:bool=False):
         super().__init__()
         self.base_dir = base_dir
         self.name = name
@@ -59,7 +58,7 @@ class GANVisualizationHook():
         self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netD, tbwriter=self.tbwriter) 
         self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netG, tbwriter=self.tbwriter) 
 
-    def train_loop_hook(self, gresult: GenResult, cresult: CriticResult): 
+    def train_loop_hook(self, gresult:GenResult, cresult:CriticResult): 
         if self.trainer.iters % self.stats_iters == 0:
             self.stats_vis.print_stats_in_jupyter(gresult, cresult)
             self.stats_vis.write_tensorboard_stats(gresult, cresult, iter_count=self.trainer.iters, tbwriter=self.tbwriter) 
@@ -80,8 +79,8 @@ class GANVisualizationHook():
 
 
 class ModelVisualizationCallback(Callback):
-    def __init__(self, base_dir: Path, model: nn.Module,  md: ModelData, name: str, stats_iters: int=25, 
-        visual_iters: int=200, weight_iters: int=25, jupyter:bool=False):
+    def __init__(self, base_dir:Path, model:nn.Module, md:ModelData, name:str, stats_iters:int=25, 
+            visual_iters:int=200, weight_iters:int=25, jupyter:bool=False):
         super().__init__()
         self.base_dir = base_dir
         self.name = name
@@ -98,6 +97,7 @@ class ModelVisualizationCallback(Callback):
         self.learner_vis = LearnerStatsVisualizer()
         self.graph_vis = ModelGraphVisualizer()
         self.weight_vis = ModelHistogramVisualizer()
+        self.img_gen_vis = ImageGenVisualizer()
 
     def on_train_begin(self):
         self.output_model_graph()

+ 7 - 25
fasterai/dataset.py

@@ -4,11 +4,11 @@ from fastai.core import *
 
 
 class MatchedFilesDataset(FilesDataset):
-    def __init__(self, fnames, y, transform, path, x_tfms=[], x_noise=None):
+    def __init__(self, fnames:np.array, y:np.array, transforms:[Transform], path:Path, x_tfms:[Transform]=[]):
         self.y=y
         self.x_tfms=x_tfms
         assert(len(fnames)==len(y))
-        super().__init__(fnames, transform, path)
+        super().__init__(fnames, transforms, path)
     def get_x(self, i): 
         x = super().get_x(i)
         for tfm in self.x_tfms:
@@ -19,33 +19,15 @@ class MatchedFilesDataset(FilesDataset):
     def get_c(self): 
         return 0 
 
-class NoiseVectorToImageDataset(FilesDataset):
-    def __init__(self, fnames, y, transform, path:Path, x_tfms=[], x_noise=64):
-        self.y=y
-        assert(len(fnames)==len(y))
-        self.x_noise=x_noise
-        super().__init__(fnames, transform, path)
-    def get_y(self, i): 
-        return open_image(os.path.join(self.path, self.y[i]))
-    def get_x(self, i): 
-        return np.random.normal(loc=0.0, scale=1.0, size=(self.x_noise,1,1))
-    def get_c(self): 
-        return 0 
-    def get(self, tfm, x, y):
-        return (x,y) if tfm is None else (x, tfm(y,y)[1])
-
-
 class ImageGenDataLoader():
-    def __init__(self, sz:int, bs:int, path:Path, random_seed=None, x_noise:int=None, 
-            keep_pct:float=1.0, x_tfms:[Transform]=[], file_exts=('jpg','jpeg','png'), 
-            extra_aug_tfms:[Transform]=[], reduce_x_scale=1):
+    def __init__(self, sz:int, bs:int, path:Path, random_seed:int=None, keep_pct:float=1.0, x_tfms:[Transform]=[], 
+            file_exts=('jpg','jpeg','png'), extra_aug_tfms:[Transform]=[], reduce_x_scale:int=1):
         
         self.md = None
         self.sz = sz
         self.bs = bs 
         self.path = path
         self.x_tfms = x_tfms
-        self.x_noise = x_noise
         self.random_seed = random_seed
         self.keep_pct = keep_pct
         self.file_exts = file_exts
@@ -64,8 +46,8 @@ class ImageGenDataLoader():
         sz_x = self.sz//self.reduce_x_scale
         sz_y = self.sz
         tfms = (tfms_from_stats(inception_stats, sz=sz_x, sz_y=sz_y, tfm_y=TfmType.PIXEL, aug_tfms=aug_tfms))
-        dstype = NoiseVectorToImageDataset if self.x_noise is not None else MatchedFilesDataset
-        datasets = ImageData.get_ds(dstype, (trn_x,trn_y), (val_x,val_y), tfms, path=self.path, x_tfms=self.x_tfms, x_noise=self.x_noise)
+        dstype = MatchedFilesDataset
+        datasets = ImageData.get_ds(dstype, (trn_x,trn_y), (val_x,val_y), tfms, path=self.path, x_tfms=self.x_tfms)
         resize_path = os.path.join(self.path,resize_folder,str(resize_amt))
         self.md = self._load_model_data(resize_folder, resize_path, resize_amt, datasets, trn_x)
         return self.md
@@ -106,7 +88,7 @@ class ImageGenDataLoader():
         return self.sz
 
     
-    def _find_files_recursively(self, root_path: Path, extensions: (str)):
+    def _find_files_recursively(self, root_path:Path, extensions:(str)):
         matches = []
         for root, dirnames, filenames in os.walk(str(root_path)):
             for filename in filenames:

+ 67 - 8
fasterai/generators.py

@@ -7,14 +7,14 @@ class GeneratorModule(ABC, nn.Module):
     def __init__(self):
         super().__init__()
     
-    def set_trainable(self, trainable: bool):
+    def set_trainable(self, trainable:bool):
         set_trainable(self, trainable)
 
     @abstractmethod
-    def get_layer_groups(self, precompute: bool = False)->[]:
+    def get_layer_groups(self, precompute:bool=False)->[]:
         pass
 
-    def freeze_to(self, n):
+    def freeze_to(self, n:int):
         c=self.get_layer_groups()
         for l in c:     set_trainable(l, False)
         for l in c[n:]: set_trainable(l, True)
@@ -25,7 +25,7 @@ class GeneratorModule(ABC, nn.Module):
  
 class Unet34(GeneratorModule): 
     @staticmethod
-    def get_pretrained_resnet_base(layers_cut:int= 0):
+    def get_pretrained_resnet_base(layers_cut:int=0):
         f = resnet34
         cut,lr_cut = model_meta[f]
         cut-=layers_cut
@@ -42,7 +42,6 @@ class Unet34(GeneratorModule):
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
         self.relu = nn.ReLU()
         self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
-
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up3 = UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
@@ -51,7 +50,7 @@ class Unet34(GeneratorModule):
         self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
 
     #Gets around irritating inconsistent halving coming from resnet
-    def _pad(self, x: torch.Tensor, target: torch.Tensor)-> torch.Tensor:
+    def _pad(self, x:torch.Tensor, target:torch.Tensor)-> torch.Tensor:
         h = x.shape[2] 
         w = x.shape[3]
 
@@ -65,7 +64,67 @@ class Unet34(GeneratorModule):
 
         return x
            
-    def forward(self, x_in: torch.Tensor):
+    def forward(self, x_in:torch.Tensor):
+        x = self.rn(x_in)
+        x = self.relu(x)
+        x = self.up1(x, self._pad(self.sfs[3].features, x))
+        x = self.up2(x, self._pad(self.sfs[2].features, x))
+        x = self.up3(x, self._pad(self.sfs[1].features, x))
+        x = self.up4(x, self._pad(self.sfs[0].features, x))
+        x = self.up5(x)
+        x = self.out(x)
+        return x
+    
+    def get_layer_groups(self, precompute:bool=False)->[]:
+        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
+        return lgs + [children(self)[1:]]
+    
+    def close(self):
+        for sf in self.sfs: 
+            sf.remove()
+
+
+class Unet34_V2(GeneratorModule): 
+    @staticmethod
+    def get_pretrained_resnet_base(layers_cut:int=0):
+        f = resnet34
+        cut,lr_cut = model_meta[f]
+        cut-=layers_cut
+        layers = cut_model(f(True), cut)
+        return nn.Sequential(*layers), lr_cut
+
+    def __init__(self, nf_factor:int=1, scale:int=1):
+        super().__init__()
+        assert (math.log(scale,2)).is_integer()
+        leakyReLu=False
+        self_attention=True
+        bn=True
+        sn=True
+        self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
+        self.relu = nn.ReLU()
+        self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
+        self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
+        self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
+        self.up3 = UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
+        self.up4 = UnetBlock(256*nf_factor,64,128*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
+        self.up5 = UpSampleBlock(128*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn) 
+        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
+
+    #Gets around irritating inconsistent halving coming from resnet
+    def _pad(self, x:torch.Tensor, target:torch.Tensor)-> torch.Tensor:
+        h = x.shape[2] 
+        w = x.shape[3]
+        target_h = target.shape[2]*2
+        target_w = target.shape[3]*2
+
+        if h<target_h or w<target_w:
+            padh = target_h-h if target_h > h else 0
+            padw = target_w-w if target_w > w else 0
+            return F.pad(x, (0,padw,0,padh), "constant",0)
+
+        return x
+           
+    def forward(self, x_in:torch.Tensor):
         x = self.rn(x_in)
         x = self.relu(x)
         x = self.up1(x, self._pad(self.sfs[3].features, x))
@@ -76,7 +135,7 @@ class Unet34(GeneratorModule):
         x = self.out(x)
         return x
     
-    def get_layer_groups(self, precompute: bool = False)->[]:
+    def get_layer_groups(self, precompute:bool=False)->[]:
         lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
         return lgs + [children(self)[1:]]
     

+ 5 - 23
fasterai/images.py

@@ -10,11 +10,11 @@ from datetime import datetime
 
 
 class EasyTensorImage():
-    def __init__(self, source_tensor: torch.Tensor, ds:FilesDataset):
+    def __init__(self, source_tensor:torch.Tensor, ds:FilesDataset):
         self.array = self._convert_to_denormed_ndarray(source_tensor, ds=ds)   
         self.tensor = self._convert_to_denormed_tensor(self.array)
     
-    def _convert_to_denormed_ndarray(self, raw_tensor: torch.Tensor, ds:FilesDataset):
+    def _convert_to_denormed_ndarray(self, raw_tensor:torch.Tensor, ds:FilesDataset):
         raw_array = raw_tensor.clone().data.cpu().numpy()
         if raw_array.shape[1] != 3:
             array = np.zeros((3, 1, 1))
@@ -27,46 +27,28 @@ class EasyTensorImage():
 
 class ModelImageSet():
     @staticmethod
-    def get_list_from_model(ds: FilesDataset, model: nn.Module, idxs:[int]):
+    def get_list_from_model(ds:FilesDataset, model:nn.Module, idxs:[int]):
         image_sets = []
-        rand = ModelImageSet._is_random_vector(ds[0][0])
         training = model.training
         model.eval()
         
         for idx in idxs:
             x,y=ds[idx]
-
-            if rand: 
-                #Making fixed noise, for consistent output
-                np.random.seed(idx)
-                orig_tensor = VV(np.random.normal(loc=0.0, scale=1.0, size=(1, x.shape[0],1,1)))
-            else:
-                orig_tensor = VV(x[None]) 
-
+            orig_tensor = VV(x[None]) 
             real_tensor = V(y[None])
             gen_tensor = model(orig_tensor)
-
             gen_easy = EasyTensorImage(gen_tensor, ds)
             orig_easy = EasyTensorImage(orig_tensor, ds)
             real_easy = EasyTensorImage(real_tensor, ds)
-
             image_set = ModelImageSet(orig_easy,real_easy,gen_easy)
             image_sets.append(image_set)
-        
-        #reseting noise back to random random
-        if rand:
-            np.random.seed()
 
         if training:
             model.train()
 
         return image_sets  
 
-    @staticmethod
-    def _is_random_vector(x):
-        return x.shape[0] != 3
-
-    def __init__(self, orig: EasyTensorImage, real: EasyTensorImage, gen: EasyTensorImage):
+    def __init__(self, orig:EasyTensorImage, real:EasyTensorImage, gen:EasyTensorImage):
         self.orig=orig
         self.real=real
         self.gen=gen

+ 5 - 7
fasterai/loss.py

@@ -7,12 +7,10 @@ import torchvision.models as models
 
 
 class FeatureLoss(nn.Module):
-    def __init__(self, block_wgts: [float] = [0.2,0.7,0.1], multiplier:float=1.0):
+    def __init__(self, block_wgts:[float]=[0.2,0.7,0.1], multiplier:float=1.0):
         super().__init__()
-        m_vgg = vgg16(True)
-        
-        blocks = [i-1 for i,o in enumerate(children(m_vgg))
-              if isinstance(o,nn.MaxPool2d)]
+        m_vgg = vgg16(True)  
+        blocks = [i-1 for i,o in enumerate(children(m_vgg)) if isinstance(o,nn.MaxPool2d)]
         blocks, [m_vgg[i] for i in blocks]
         layer_ids = blocks[:3]
         
@@ -24,7 +22,7 @@ class FeatureLoss(nn.Module):
         self.sfs = [SaveFeatures(m_vgg[i]) for i in layer_ids]
         self.multiplier = multiplier
 
-    def forward(self, input, target, sum_layers=True):
+    def forward(self, input, target, sum_layers:bool=True):
         self.m(VV(target.data))
         res = [F.l1_loss(input,target)/100]
         targ_feat = [V(o.features.data.clone()) for o in self.sfs]
@@ -34,7 +32,7 @@ class FeatureLoss(nn.Module):
         if sum_layers: res = sum(res)
         return res*self.multiplier
     
-    def _flatten(self, x): 
+    def _flatten(self, x:torch.Tensor): 
         return x.view(x.size(0), -1)
     
     def close(self):

+ 10 - 125
fasterai/modules.py

@@ -26,60 +26,16 @@ class ConvBlock(nn.Module):
     def forward(self, x):
         return self.seq(x)
 
-class MeanPoolConv(nn.Module):
-    def __init__(self, ni, no, sn:bool=False, leakyReLu:bool=False):
-        super(MeanPoolConv, self).__init__()
-        self.conv = ConvBlock(ni, no, pad=0, ks=1, bn=False, sn=sn, leakyReLu=leakyReLu)
-
-    def forward(self, input):
-        output = input
-        output = (output[:,:,::2,::2] + output[:,:,1::2,::2] + output[:,:,::2,1::2] + output[:,:,1::2,1::2]) / 4
-        output = self.conv(output)
-        return output
-
-class ConvPoolMean(nn.Module):
-    def __init__(self, ni, no, ks:int=3, sn:bool=False, leakyReLu:bool=False):
-        super(ConvPoolMean, self).__init__()
-        self.conv = ConvBlock(ni, no, ks=ks, bn=False, sn=sn, leakyReLu=leakyReLu)
-
-    def forward(self, input):
-        output = input
-        output = self.conv(output)
-        output = (output[:,:,::2,::2] + output[:,:,1::2,::2] + output[:,:,::2,1::2] + output[:,:,1::2,1::2]) / 4
-        return output
-
-class DeconvBlock(nn.Module):
-    def __init__(self, ni:int, no:int, ks:int, stride:int, pad:int, bn:bool=True, 
-            sn:bool=False, leakyReLu:bool=False, self_attention:bool=False, inplace_relu:bool=True):
-        super().__init__()
-
-        layers=[]
-
-        conv = nn.ConvTranspose2d(ni, no, ks, stride, padding=pad, bias=False)
-        if sn: 
-            conv = spectral_norm(conv)
-        layers.append(conv)
-        if bn:
-            layers.append(nn.BatchNorm2d(no))
-        if leakyReLu:
-            layers.append(nn.LeakyReLU(0.2, inplace=inplace_relu))
-        else:
-            layers.append(nn.ReLU(inplace=inplace_relu))
-        if self_attention:
-            layers.append(SelfAttention(no, 1))
-        self.out=nn.Sequential(*layers)
-        
-    def forward(self, x):
-        return self.out(x)
 
 class UpSampleBlock(nn.Module):
     @staticmethod
-    def _conv(ni: int, nf: int, ks: int=3, bn=True, sn=False, leakyReLu:bool=False):
+    def _conv(ni:int, nf:int, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
         layers = [ConvBlock(ni, nf, ks=ks, sn=sn, bn=bn, actn=False, leakyReLu=leakyReLu)]
         return nn.Sequential(*layers)
 
     @staticmethod
-    def _icnr(x:torch.Tensor, scale:int =2, init=nn.init.kaiming_normal_):
+    def _icnr(x:torch.Tensor, scale:int=2):
+        init=nn.init.kaiming_normal_
         new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
         subkernel = torch.zeros(new_shape)
         subkernel = init(subkernel)
@@ -95,7 +51,6 @@ class UpSampleBlock(nn.Module):
     def __init__(self, ni:int, nf:int, scale:int=2, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
         super().__init__()
         layers = []
-
         assert (math.log(scale,2)).is_integer()
 
         for i in range(int(math.log(scale,2))):
@@ -118,70 +73,9 @@ class UpSampleBlock(nn.Module):
         return self.sequence(x)
 
 
-class ResSequential(nn.Module):
-    def __init__(self, layers:[], res_scale:float=1.0):
-        super().__init__()
-        self.res_scale = res_scale
-        self.m = nn.Sequential(*layers)
-
-    def forward(self, x): 
-        return x + self.m(x) * self.res_scale
-
-class ResBlock(nn.Module):
-    def __init__(self, nf:int, ks:int=3, res_scale:float=1.0, dropout:float=0.5, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
-        super().__init__()
-        layers = []
-        nf_bottleneck = nf//4
-        self.res_scale = res_scale
-        layers.append(ConvBlock(nf, nf_bottleneck, ks=ks, bn=bn, sn=sn, leakyReLu=leakyReLu))
-        layers.append(nn.Dropout2d(dropout))
-        layers.append(ConvBlock(nf_bottleneck, nf, ks=ks, actn=False, bn=False, sn=sn))
-        self.mid = nn.Sequential(*layers)
-        self.relu = nn.LeakyReLU(0.2) if leakyReLu else nn.ReLU()
-    
-    def forward(self, x):
-        x = self.mid(x)*self.res_scale+x
-        x = self.relu(x)
-        return x
-
-
-class DownSampleResBlock(nn.Module):
-    def __init__(self, ni:int, nf:int, res_scale:float=1.0, dropout:float=0.5, bn:bool=True, sn:bool=False, leakyReLu:bool=False,
-            inplace_relu:bool=True):
-        super().__init__()
-        self.res_scale = res_scale
-        layers = []
-        layers.append(ConvBlock(ni, nf, ks=4, stride=2, bn=bn, sn=sn, leakyReLu=leakyReLu))
-        layers.append(nn.Dropout2d(dropout))
-
-        self.mid = nn.Sequential(*layers)
-        self.mid_shortcut = MeanPoolConv(ni, nf, sn=sn, leakyReLu=leakyReLu)
-        self.relu = nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu)
-    
-    def forward(self, x):
-        x = self.mid(x)*self.res_scale + self.mid_shortcut(x)*self.res_scale
-        x = self.relu(x)
-        return x
-
-class FilterScalingBlock(nn.Module):
-    def __init__(self, ni:int, nf:int, ks:int=3, res_scale:float=1.0, dropout:float=0.5, bn:bool=True, sn:bool=False, leakyReLu:bool=False,
-            inplace_relu:bool=True):
-        super().__init__()
-        self.res_scale = res_scale
-        layers = []
-        layers.append(ConvBlock(ni, nf, ks=1, bn=bn, sn=sn, leakyReLu=leakyReLu))
-        layers.append(nn.Dropout2d(dropout))
-        self.mid = nn.Sequential(*layers)
-        self.relu = nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu)
-    
-    def forward(self, x):
-        x = self.mid(x)*self.res_scale
-        x = self.relu(x)
-        return x 
-
 class UnetBlock(nn.Module):
     def __init__(self, up_in:int , x_in:int , n_out:int, bn:bool=True, sn:bool=False, leakyReLu:bool=False, 
-        self_attention:bool=False, inplace_relu:bool=True):
+            self_attention:bool=False, inplace_relu:bool=True):
         super().__init__()
         up_out = x_out = n_out//2
         self.x_conv  = ConvBlock(x_in,  x_out,  ks=1, bn=False, actn=False, sn=sn, inplace_relu=inplace_relu)
@@ -203,14 +97,6 @@ class UnetBlock(nn.Module):
         return self.out(x)
         return out
 
-def get_pretrained_resnet_base(layers_cut:int= 0):
-    f = resnet34
-    cut,lr_cut = model_meta[f]
-    cut-=layers_cut
-    layers = cut_model(f(True), cut)
-    return nn.Sequential(*layers), lr_cut
-
-
 class SaveFeatures():
     features=None
     def __init__(self, m:nn.Module): 
@@ -221,21 +107,21 @@ class SaveFeatures():
         self.hook.remove()
 
 class SelfAttention(nn.Module):
-    def __init__(self, in_channel, gain=1):
+    def __init__(self, in_channel:int, gain:int=1):
         super().__init__()
-        self.query = self.spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
-        self.key = self.spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
-        self.value = self.spectral_init(nn.Conv1d(in_channel, in_channel, 1), gain=gain)
+        self.query = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
+        self.key = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
+        self.value = self._spectral_init(nn.Conv1d(in_channel, in_channel, 1), gain=gain)
         self.gamma = nn.Parameter(torch.tensor(0.0))
 
-    def spectral_init(self, module, gain=1):
+    def _spectral_init(self, module:nn.Module, gain:int=1):
         nn.init.kaiming_uniform_(module.weight, gain)
         if module.bias is not None:
             module.bias.data.zero_()
 
         return spectral_norm(module)
 
-    def forward(self, input):
+    def forward(self, input:torch.Tensor):
         shape = input.shape
         flatten = input.view(shape[0], shape[1], -1)
         query = self.query(flatten).permute(0, 2, 1)
@@ -246,5 +132,4 @@ class SelfAttention(nn.Module):
         attn = torch.bmm(value, attn)
         attn = attn.view(*shape)
         out = self.gamma * attn + input
-
         return out

+ 18 - 19
fasterai/training.py

@@ -31,7 +31,6 @@ class CriticModule(ABC, nn.Module):
         return next(self.parameters()).device
 
 class DCCritic(CriticModule):
-
     def _generate_reduce_layers(self, nf:int):
         layers=[]
         layers.append(nn.Dropout2d(0.5))
@@ -74,13 +73,13 @@ class DCCritic(CriticModule):
 
 
 class GenResult():
-    def __init__(self, gcost: np.array, iters: int, gaddlloss: np.array):
+    def __init__(self, gcost:np.array, iters:int, gaddlloss:np.array):
         self.gcost=gcost
         self.iters=iters
         self.gaddlloss=gaddlloss
 
 class CriticResult():
-    def __init__(self, hingeloss: np.array, dreal: np.array, dfake: np.array, dcost: np.array):
+    def __init__(self, hingeloss:np.array, dreal:np.array, dfake:np.array, dcost:np.array):
         self.hingeloss=hingeloss
         self.dreal=dreal
         self.dfake=dfake
@@ -91,8 +90,8 @@ class GANTrainSchedule():
     @staticmethod
     def generate_schedules(szs:[int], bss:[int], path:Path, keep_pcts:[float], save_base_name:str, 
         c_lrs:[float], g_lrs:[float], gen_freeze_tos:[int], lrs_unfreeze_factor:float=0.1, 
-        x_noise:int=None, random_seed=None, x_tfms:[Transform]=[], extra_aug_tfms:[Transform]=[],
-        reduce_x_scale=1):
+        random_seed:int=None, x_tfms:[Transform]=[], extra_aug_tfms:[Transform]=[],
+        reduce_x_scale:int=1):
 
         scheds = []
 
@@ -107,7 +106,7 @@ class GANTrainSchedule():
             gen_save_path = path.parent/(save_base_name + '_gen_' + str(sz) + '.h5')
             sched = GANTrainSchedule(sz=sz, bs=bs, path=path, critic_lrs=critic_lrs, gen_lrs=gen_lrs,
                 critic_save_path=critic_save_path, gen_save_path=gen_save_path, random_seed=random_seed,
-                x_noise=x_noise, keep_pct=keep_pct, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,  
+                keep_pct=keep_pct, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,  
                 reduce_x_scale=reduce_x_scale, gen_freeze_to=gen_freeze_to)
             scheds.append(sched)
         
@@ -115,12 +114,12 @@ class GANTrainSchedule():
 
 
     def __init__(self, sz:int, bs:int, path:Path, critic_lrs:[float], gen_lrs:[float],
-            critic_save_path: Path, gen_save_path: Path, random_seed=None, x_noise:int=None, 
-            keep_pct:float=1.0, num_epochs=1, x_tfms:[Transform]=[], extra_aug_tfms:[Transform]=[], 
-            reduce_x_scale=1, gen_freeze_to=0):
+            critic_save_path:Path, gen_save_path:Path, random_seed:int=None, 
+            keep_pct:float=1.0, num_epochs:int=1, x_tfms:[Transform]=[], extra_aug_tfms:[Transform]=[], 
+            reduce_x_scale:int=1, gen_freeze_to:int=0):
         self.md = None
 
-        self.data_loader = ImageGenDataLoader(sz=sz, bs=bs, path=path, random_seed=random_seed, x_noise=x_noise,
+        self.data_loader = ImageGenDataLoader(sz=sz, bs=bs, path=path, random_seed=random_seed,
             keep_pct=keep_pct, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, reduce_x_scale=reduce_x_scale)
         self.sz = sz
         self.bs = bs
@@ -184,7 +183,7 @@ class GANTrainer():
     def _get_inner_module(self, model:nn.Module):
         return model.module if isinstance(model, nn.DataParallel) else model
 
-    def _generate_clr_sched(self, model:nn.Module, use_clr_beta: (int), lrs: [float], cycle_len: int):
+    def _generate_clr_sched(self, model:nn.Module, use_clr_beta:(int), lrs:[float], cycle_len:int):
         wds = 1e-7
         opt_fn = partial(optim.Adam, betas=(0.0,0.9))
         layer_opt = LayerOptimizer(opt_fn, self._get_inner_module(model).get_layer_groups(), lrs, wds)
@@ -223,7 +222,7 @@ class GANTrainer():
                     "train begin hooks should never return any values, but '{}'"
                     "didn't return None".format(hook))
 
-    def _call_train_loop_hooks(self, gresult: GenResult, cresult: CriticResult):
+    def _call_train_loop_hooks(self, gresult:GenResult, cresult:CriticResult):
         for hook in self._train_loop_hooks.values():
             hook_result = hook(gresult, cresult)
             if hook_result is not None:
@@ -231,7 +230,7 @@ class GANTrainer():
                     "train loop hooks should never return any values, but '{}'"
                     "didn't return None".format(hook))
 
-    def _get_next_training_images(self, data_iter: Iterable)->(torch.Tensor,torch.Tensor):
+    def _get_next_training_images(self, data_iter:Iterable)->(torch.Tensor,torch.Tensor):
         x, y = next(data_iter, (None, None))
 
         if x is None or y is None:
@@ -242,7 +241,7 @@ class GANTrainer():
         return (orig_image, real_image)
 
 
-    def _train_critic(self, data_iter: Iterable, pbar: tqdm)->CriticResult:
+    def _train_critic(self, data_iter:Iterable, pbar:tqdm)->CriticResult:
         self._get_inner_module(self.netD).set_trainable(True)
         self._get_inner_module(self.netG).set_trainable(False)
         orig_image, real_image = self._get_next_training_images(data_iter)
@@ -253,7 +252,7 @@ class GANTrainer():
         pbar.update()
         return cresult
 
-    def _train_critic_once(self, orig_image: torch.Tensor, real_image: torch.Tensor)->CriticResult:                     
+    def _train_critic_once(self, orig_image:torch.Tensor, real_image:torch.Tensor)->CriticResult:                     
         fake_image = self.netG(orig_image)
         dfake_raw,_ = self.netD(fake_image)
         dfake = torch.nn.ReLU()(1.0+dfake_raw).mean()
@@ -267,7 +266,7 @@ class GANTrainer():
         self.gen_sched.on_batch_end(to_np(-dfake))
         return CriticResult(to_np(hingeloss), to_np(dreal), to_np(dfake), to_np(hingeloss))
 
-    def _train_generator(self, data_iter: Iterable, pbar: tqdm, cresult: CriticResult)->GenResult:
+    def _train_generator(self, data_iter:Iterable, pbar:tqdm, cresult:CriticResult)->GenResult:
         orig_image, real_image = self._get_next_training_images(data_iter)   
         if orig_image is None:
             return None
@@ -276,7 +275,7 @@ class GANTrainer():
         pbar.update() 
         return gresult
 
-    def _train_generator_once(self, orig_image: torch.Tensor, real_image: torch.Tensor, 
+    def _train_generator_once(self, orig_image:torch.Tensor, real_image:torch.Tensor, 
             cresult: CriticResult)->GenResult:
         self._get_inner_module(self.netD).set_trainable(False)
         self._get_inner_module(self.netG).set_trainable(True)
@@ -297,12 +296,12 @@ class GANTrainer():
             save_model(self.netD, self.dpath)
             save_model(self.netG, self.gpath)
 
-    def _get_dscore(self, new_image: torch.Tensor):
+    def _get_dscore(self, new_image:torch.Tensor):
         scores, _ = self.netD(new_image)
         return scores.mean()
     
 
-    def _calc_addl_gen_loss(self, real_data: torch.Tensor, fake_data: torch.Tensor)->torch.Tensor:
+    def _calc_addl_gen_loss(self, real_data:torch.Tensor, fake_data:torch.Tensor)->torch.Tensor:
         total_loss = V(0.0)
         for loss_fn in self.genloss_fns:
             loss = loss_fn(fake_data, real_data)

+ 18 - 16
fasterai/visualize.py

@@ -17,7 +17,7 @@ class ModelImageVisualizer():
     def __init__(self, default_sz:int=500):
         self.default_sz=default_sz 
 
-    def plot_transformed_image(self, path: Path, model: nn.Module, ds:FilesDataset, figsize=(20,20), sz:int=None, 
+    def plot_transformed_image(self, path:Path, model:nn.Module, ds:FilesDataset, figsize:(int,int)=(20,20), sz:int=None, 
             tfms:[Transform]=[], compare:bool=True):
         result = self.get_transformed_image_ndarray(path, model,ds, sz, tfms=tfms)
         if compare: 
@@ -28,7 +28,7 @@ class ModelImageVisualizer():
         else:
             self.plot_image_from_ndarray(result, figsize=figsize)
 
-    def get_transformed_image_ndarray(self, path: Path, model: nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
         training = model.training 
         model.eval()
         orig = self.get_model_ready_image_ndarray(path, model, ds, sz, tfms)
@@ -39,7 +39,7 @@ class ModelImageVisualizer():
             model.train()
         return result[0]
 
-    def _transform(self, orig, tfms:[Transform], model: nn.Module, sz:int):
+    def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
         for tfm in tfms:
             orig,_=tfm(orig, False)
         _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
@@ -47,14 +47,14 @@ class ModelImageVisualizer():
         orig = val_tfms(orig)
         return orig
 
-    def get_model_ready_image_ndarray(self, path: Path, model: nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
         im = open_image(str(path))
         sz = self.default_sz if sz is None else sz
         im = scale_min(im, sz)
         im = self._transform(im, tfms, model, sz)
         return im
 
-    def plot_image_from_ndarray(self, image: ndarray, axes:Axes=None, figsize=(20,20)):
+    def plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None: 
             _,axes = plt.subplots(figsize=figsize)
         clipped_image =np.clip(image,0,1)
@@ -62,7 +62,8 @@ class ModelImageVisualizer():
         axes.axis('off')
 
 
-    def plot_images_from_image_sets(self, image_sets: [ModelImageSet], validation:bool, figsize=(20,20), max_columns=6, immediate_display=True):
+    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
+            max_columns:int=6, immediate_display:bool=True):
         num_sets = len(image_sets)
         num_images = num_sets * 2
         rows, columns = self._get_num_rows_columns(num_images, max_columns)
@@ -79,11 +80,12 @@ class ModelImageVisualizer():
             display(fig)
 
 
-    def plot_image_outputs_from_model(self, ds: FilesDataset, model: nn.Module, idxs: [int], figsize=(20,20), max_columns=6, immediate_display=True):
+    def plot_image_outputs_from_model(self, ds:FilesDataset, model:nn.Module, idxs:[int], figsize:(int,int)=(20,20), max_columns:int=6, 
+            immediate_display:bool=True):
         image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
         self.plot_images_from_image_sets(image_sets=image_sets, figsize=figsize, max_columns=max_columns, immediate_display=immediate_display)
 
-    def _get_num_rows_columns(self, num_images: int, max_columns: int):
+    def _get_num_rows_columns(self, num_images:int, max_columns:int):
         columns = min(num_images, max_columns)
         rows = num_images//columns
         rows = rows if rows * columns == num_images else rows + 1
@@ -94,7 +96,7 @@ class ModelGraphVisualizer():
     def __init__(self):
         return 
      
-    def write_model_graph_to_tensorboard(self, ds: FilesDataset, model: nn.Module, tbwriter: SummaryWriter):
+    def write_model_graph_to_tensorboard(self, ds:FilesDataset, model:nn.Module, tbwriter:SummaryWriter):
         try:
             x,_=ds[0]
             tbwriter.add_graph(model, V(x[None]))
@@ -106,7 +108,7 @@ class ModelHistogramVisualizer():
     def __init__(self):
         return 
 
-    def write_tensorboard_histograms(self, model: nn.Module, iter_count:int, tbwriter: SummaryWriter):
+    def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
         for name, param in model.named_parameters():
             tbwriter.add_histogram('/weights/' + name, param, iter_count)
     
@@ -116,7 +118,7 @@ class ModelStatsVisualizer():
     def __init__(self):
         return 
 
-    def write_tensorboard_stats(self, model: nn.Module, iter_count:int, tbwriter: SummaryWriter):
+    def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
         gradients = [x.grad  for x in model.parameters() if x.grad is not None]
         gradient_nps = [to_np(x.data) for x in gradients]
  
@@ -155,11 +157,11 @@ class ImageGenVisualizer():
     def __init__(self):
         self.model_vis = ModelImageVisualizer()
 
-    def output_image_gen_visuals(self, md: ImageData, model: nn.Module, iter_count:int, tbwriter: SummaryWriter, jupyter:bool=False):
+    def output_image_gen_visuals(self, md:ImageData, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, jupyter:bool=False):
         self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=True)
         self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=False)
 
-    def _output_visuals(self, ds: FilesDataset, model: nn.Module, iter_count:int, tbwriter: SummaryWriter, 
+    def _output_visuals(self, ds:FilesDataset, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, 
             validation:bool, jupyter:bool=False):
         #TODO:  Parameterize these
         start_idx=0
@@ -171,7 +173,7 @@ class ImageGenVisualizer():
         if jupyter:
             self._show_images_in_jupyter(image_sets, validation=validation)
     
-    def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter: SummaryWriter, validation:bool):
+    def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, validation:bool):
         orig_images = []
         gen_images = []
         real_images = []
@@ -201,7 +203,7 @@ class GANTrainerStatsVisualizer():
     def __init__(self):
         return
 
-    def write_tensorboard_stats(self, gresult: GenResult, cresult: CriticResult, iter_count:int, tbwriter: SummaryWriter):
+    def write_tensorboard_stats(self, gresult:GenResult, cresult:CriticResult, iter_count:int, tbwriter:SummaryWriter):
         tbwriter.add_scalar('/loss/hingeloss', cresult.hingeloss, iter_count)
         tbwriter.add_scalar('/loss/dfake', cresult.dfake, iter_count)
         tbwriter.add_scalar('/loss/dreal', cresult.dreal, iter_count)
@@ -209,7 +211,7 @@ class GANTrainerStatsVisualizer():
         tbwriter.add_scalar('/loss/gcount', gresult.iters, iter_count)
         tbwriter.add_scalar('/loss/gaddlloss', gresult.gaddlloss, iter_count)
 
-    def print_stats_in_jupyter(self, gresult: GenResult, cresult: CriticResult):
+    def print_stats_in_jupyter(self, gresult:GenResult, cresult:CriticResult):
         print(f'\nHingeLoss {cresult.hingeloss}; RScore {cresult.dreal}; FScore {cresult.dfake}; GAddlLoss {gresult.gaddlloss}; ' + 
                 f'Iters: {gresult.iters}; GCost: {gresult.gcost};')
 

Some files were not shown because too many files changed in this diff