Procházet zdrojové kódy

Initial commit of new filter functionality (big!)

Includes chrominance/luminance optimization, change to using render_factor instead of size for easier "optimal result image" search; changing to rendering in model using stretched full squares for real world images. All these changes amount to huge bump in default quality, more flexibility in hardware requirements, and graceful degradation in rendering as memory is decreased.
Jason Antic před 6 roky
rodič
revize
dabb3a00ed

Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 113 - 113
ColorizeVisualization.ipynb


+ 0 - 0
fasterai/__init__.py


+ 4 - 3
fasterai/callbacks.py

@@ -4,6 +4,7 @@ from fastai.dataset import ModelData, ImageData
 from .visualize import ModelStatsVisualizer, ImageGenVisualizer, GANTrainerStatsVisualizer
 from .visualize import LearnerStatsVisualizer, ModelGraphVisualizer, ModelHistogramVisualizer
 from .training import GenResult, CriticResult, GANTrainer
+from .generators import GeneratorModule
 from tensorboardX import SummaryWriter
 
 def clear_directory(dir:Path):
@@ -66,7 +67,7 @@ class GANVisualizationHook():
         if self.trainer.iters % self.visual_iters == 0:
             model = self.trainer.netG
             self.img_gen_vis.output_image_gen_visuals(md=self.trainer.md, model=model, iter_count=self.trainer.iters, 
-                tbwriter=self.tbwriter, jupyter=self.jupyter)
+                tbwriter=self.tbwriter)
 
         if self.trainer.iters % self.weight_iters == 0:
             self.weight_vis.write_tensorboard_histograms(model=self.trainer.netG, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
@@ -79,7 +80,7 @@ class GANVisualizationHook():
 
 
 class ModelVisualizationCallback(Callback):
-    def __init__(self, base_dir:Path, model:nn.Module, md:ModelData, name:str, stats_iters:int=25, 
+    def __init__(self, base_dir:Path, model:GeneratorModule, md:ModelData, name:str, stats_iters:int=25, 
             visual_iters:int=200, weight_iters:int=25, jupyter:bool=False):
         super().__init__()
         self.base_dir = base_dir
@@ -146,7 +147,7 @@ class ModelVisualizationCallback(Callback):
         self.tbwriter.close()
 
 class ImageGenVisualizationCallback(ModelVisualizationCallback):
-    def __init__(self, base_dir: Path, model: nn.Module,  md: ImageData, name: str, stats_iters: int=25, visual_iters: int=200, jupyter:bool=False):
+    def __init__(self, base_dir: Path, model: GeneratorModule,  md: ImageData, name: str, stats_iters: int=25, visual_iters: int=200, jupyter:bool=False):
         super().__init__(base_dir=base_dir, model=model,  md=md, name=name, stats_iters=stats_iters, visual_iters=visual_iters, jupyter=jupyter)
         self.img_gen_vis = ImageGenVisualizer()
 

+ 100 - 0
fasterai/filters.py

@@ -0,0 +1,100 @@
+from numpy import ndarray
+from abc import ABC, abstractmethod
+from .generators import Unet34, GeneratorModule
+from .transforms import BlackAndWhiteTransform
+from fastai.torch_imports import *
+from fastai.core import *
+from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
+from fastai.transforms import CropType, NoCrop, Denormalize, Scale, scale_to
+import math
+from scipy import misc
+
+class Padding():
+    def __init__(self, top:int, bottom:int, left:int, right:int):
+        self.top = top
+        self.bottom = bottom
+        self.left = left
+        self.right = right
+  
+class Filter(ABC):
+    def __init__(self, tfms:[Transform]):
+        super().__init__()
+        self.tfms=tfms
+        self.denorm = Denormalize(*inception_stats)
+    
+    @abstractmethod
+    def filter(self, orig_image:ndarray, render_factor:int)->ndarray:
+        pass
+
+    def _transform(self, orig:ndarray, sz:int):
+        for tfm in self.tfms:
+            orig,_=tfm(orig, False)
+        _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
+        val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not (isinstance(tfm, NoCrop) or isinstance(tfm, Scale))]
+        orig = val_tfms(orig)
+        return orig
+
+    def _scale_to_square(self, orig:ndarray, targ:int, interpolation=cv2.INTER_AREA):
+        r,c,*_ = orig.shape
+        ratio = targ/max(r,c)
+        #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
+        #I've tried padding to the square as well (reflect, symetric, constant, etc).  Not as good!
+        sz = (targ, targ)
+        return cv2.resize(orig, sz, interpolation=interpolation)
+
+    def _get_model_ready_image_ndarray(self, orig:ndarray, sz:int):
+        result = self._scale_to_square(orig, sz)
+        sz=result.shape[0]
+        result = self._transform(result, sz)
+        return result
+
+    def _denorm(self, image: ndarray):
+        if len(image.shape)==3: 
+            image = image[None]
+        return self.denorm(np.rollaxis(image,1,4))
+
+    def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int):
+        orig = self._get_model_ready_image_ndarray(orig, sz)
+        orig = VV_(orig[None]) 
+        result = model(orig)
+        result = result.detach().cpu().numpy()
+        result = self._denorm(result)
+        return result[0]
+
+    def _convert_to_pil(self, im_array:ndarray):
+        im_array = np.clip(im_array,0,1)
+        return misc.toimage(im_array)
+
+
+class Colorizer(Filter):
+    def __init__(self, gpu:int, weights_path:Path):
+        super().__init__(tfms=[BlackAndWhiteTransform()])
+        self.model = Unet34(nf_factor=2).cuda(gpu)
+        load_model(self.model, weights_path)
+        self.model.eval()
+        torch.no_grad()
+        self.render_base = 32
+    
+    def filter(self, orig_image:ndarray, render_factor:int=14)->ndarray:
+        render_sz = render_factor * self.render_base
+        model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
+        return self._post_process(model_image, orig_image)
+
+
+    #This takes advantage of the fact that human eyes are much less sensitive to 
+    #imperfections in chrominance compared to luminance.  This means we can
+    #save a lot on memory and processing in the model, yet get a great high
+    #resolution result at the end.  This is primarily intended just for 
+    #inference
+    def _post_process(self, raw_color:ndarray, orig:ndarray):
+        for tfm in self.tfms:
+            orig,_=tfm(orig, False)
+
+        sz = (orig.shape[1], orig.shape[0])
+        raw_color = cv2.resize(raw_color, sz, interpolation=cv2.INTER_AREA)
+        color_yuv = cv2.cvtColor(raw_color, cv2.COLOR_BGR2YUV)
+        #do a black and white transform first to get better luminance values
+        orig_yuv = cv2.cvtColor(orig, cv2.COLOR_BGR2YUV)
+        hires = np.copy(orig_yuv)
+        hires[:,:,1:3] = color_yuv[:,:,1:3]
+        return cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)    

+ 34 - 19
fasterai/generators.py

@@ -1,7 +1,9 @@
 from fastai.core import *
 from fastai.conv_learner import model_meta, cut_model
+from fastai.transforms import scale_min
 from .modules import ConvBlock, UnetBlock, UpSampleBlock, SaveFeatures
 from abc import ABC, abstractmethod
+from torchvision import transforms
 
 class GeneratorModule(ABC, nn.Module):
     def __init__(self):
@@ -14,6 +16,10 @@ class GeneratorModule(ABC, nn.Module):
     def get_layer_groups(self, precompute:bool=False)->[]:
         pass
 
+    @abstractmethod
+    def forward(self, x_in:torch.Tensor, max_render_sz:int=400):
+        pass
+
     def freeze_to(self, n:int):
         c=self.get_layer_groups()
         for l in c:     set_trainable(l, False)
@@ -22,10 +28,10 @@ class GeneratorModule(ABC, nn.Module):
     def get_device(self):
         return next(self.parameters()).device
 
- 
+
 class Unet34(GeneratorModule): 
     @staticmethod
-    def get_pretrained_resnet_base(layers_cut:int=0):
+    def _get_pretrained_resnet_base(layers_cut:int=0):
         f = resnet34
         cut,lr_cut = model_meta[f]
         cut-=layers_cut
@@ -39,7 +45,7 @@ class Unet34(GeneratorModule):
         self_attention=True
         bn=True
         sn=True
-        self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
+        self.rn, self.lr_cut = Unet34._get_pretrained_resnet_base()
         self.relu = nn.ReLU()
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
         self.up2 = UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
@@ -72,9 +78,9 @@ class Unet34(GeneratorModule):
         target_h = x.shape[2]-padh
         target_w = x.shape[3]-padw
         return x[:,:,:target_h, :target_w]
-           
-    def forward(self, x_in:torch.Tensor):
-        x = self.rn[0](x_in)
+
+    def _encode(self, x:torch.Tensor):
+        x = self.rn[0](x)
         x = self.rn[1](x)
         x = self.rn[2](x)
         enc0 = x
@@ -86,24 +92,32 @@ class Unet34(GeneratorModule):
         x = self.rn[6](x)
         enc3 = x
         x = self.rn[7](x)
+        return (x, enc0, enc1, enc2, enc3)
 
-        padw = 0
+    def _decode(self, x:torch.Tensor, enc0:torch.Tensor, enc1:torch.Tensor, enc2:torch.Tensor, enc3:torch.Tensor):
         padh = 0
-
+        padw = 0
         x = self.relu(x)
-        penc3, padh, padw = self._pad(enc3, x, padh, padw)
-        x = self.up1(x, penc3)
-        penc2, padh, padw  = self._pad(enc2, x, padh, padw)
-        x = self.up2(x, penc2)
-        penc1, padh, padw  = self._pad(enc1, x, padh, padw)
-        x = self.up3(x, penc1)
-        penc0, padh, padw  = self._pad(enc0, x, padh, padw)
-        x = self.up4(x, penc0)
-
-        x = self._remove_padding(x, padh, padw)
-
+        enc3, padh, padw = self._pad(enc3, x, padh, padw)
+        x = self.up1(x, enc3)
+        enc2, padh, padw  = self._pad(enc2, x, padh, padw)
+        x = self.up2(x, enc2)
+        enc1, padh, padw  = self._pad(enc1, x, padh, padw)
+        x = self.up3(x, enc1)
+        enc0, padh, padw  = self._pad(enc0, x, padh, padw)
+        x = self.up4(x, enc0)
+        #This is a bit too much padding being removed, but I 
+        #haven't yet figured out a good way to determine what 
+        #exactly should be removed.  This is consistently more 
+        #than enough though.
         x = self.up5(x)
         x = self.out(x)
+        x = self._remove_padding(x, padh, padw)
+        return x
+
+    def forward(self, x:torch.Tensor):
+        x, enc0, enc1, enc2, enc3 = self._encode(x)
+        x = self._decode(x, enc0, enc1, enc2, enc3)
         return x
     
     def get_layer_groups(self, precompute:bool=False)->[]:
@@ -114,3 +128,4 @@ class Unet34(GeneratorModule):
         for sf in self.sfs: 
             sf.remove()
 
+ 

+ 0 - 1
fasterai/images.py

@@ -8,7 +8,6 @@ from PIL import Image
 from numpy import ndarray
 from datetime import datetime
 
-
 class EasyTensorImage():
     def __init__(self, source_tensor:torch.Tensor, ds:FilesDataset):
         self.array = self._convert_to_denormed_ndarray(source_tensor, ds=ds)   

+ 0 - 1
fasterai/loss.py

@@ -5,7 +5,6 @@ from fastai.conv_learner import children
 from .modules import SaveFeatures
 import torchvision.models as models
 
-
 class FeatureLoss(nn.Module):
     def __init__(self, block_wgts:[float]=[0.2,0.7,0.1], multiplier:float=1.0):
         super().__init__()

+ 0 - 1
fasterai/modules.py

@@ -2,7 +2,6 @@ from fastai.torch_imports import *
 from fastai.conv_learner import *
 from torch.nn.utils.spectral_norm import spectral_norm
 
-
 class ConvBlock(nn.Module):
     def __init__(self, ni:int, no:int, ks:int=3, stride:int=1, pad:int=None, actn:bool=True, 
             bn:bool=True, bias:bool=True, sn:bool=False, leakyReLu:bool=False, self_attention:bool=False,

+ 27 - 76
fasterai/visualize.py

@@ -6,26 +6,28 @@ from matplotlib.figure import Figure
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
-from fastai.transforms import CropType, NoCrop, Denormalize
+from fastai.transforms import CropType, NoCrop, Denormalize, Scale
+from fasterai.transforms import BlackAndWhiteTransform
 from .training import GenResult, CriticResult, GANTrainer
 from .images import ModelImageSet, EasyTensorImage
+from .generators import GeneratorModule
+from .filters import Filter, Colorizer
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
 import torchvision.utils as vutils
 import statistics
-from PIL import Image 
-
+from PIL import Image
 
 class ModelImageVisualizer():
-    def __init__(self, default_sz:int=500, results_dir:str=None):
-        self.default_sz=default_sz 
-        self.denorm = Denormalize(*inception_stats) 
+    def __init__(self, filters:[Filter]=[], render_factor:int=18, results_dir:str=None):
+        self.filters = filters
+        self.render_factor=render_factor 
         self.results_dir=None if results_dir is None else Path(results_dir)
 
-    def plot_transformed_image(self, path:str, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, tfms:[Transform]=[])->ndarray:
+    def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->ndarray:
         path = Path(path)
-        result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        result = self._get_transformed_image_ndarray(path, render_factor)
         orig = open_image(str(path))
         fig,axes = plt.subplots(1, 2, figsize=figsize)
         self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
@@ -34,63 +36,24 @@ class ModelImageVisualizer():
         if self.results_dir is not None:
             self._save_result_image(path, result)
 
-    def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
+    def get_transformed_image_as_pil(self, path:str, render_factor:int=None)->Image:
         path = Path(path)
-        array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        array = self._get_transformed_image_ndarray(path, render_factor)
         return misc.toimage(array)
 
     def _save_result_image(self, source_path:Path, result:ndarray):
         result_path = self.results_dir/source_path.name
-        misc.imsave(result_path, result)
-
-    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
-            max_columns:int=6, immediate_display:bool=True):
-        num_sets = len(image_sets)
-        num_images = num_sets * 2
-        rows, columns = self._get_num_rows_columns(num_images, max_columns)
-
-        fig, axes = plt.subplots(rows, columns, figsize=figsize)
-        title = 'Validation' if validation else 'Training'
-        fig.suptitle(title, fontsize=16)
-
-        for i, image_set in enumerate(image_sets):
-            self._plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
-            self._plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
-
-        if immediate_display:
-            display(fig)
-
-    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
-        training = model.training 
-        model.eval()
-        with torch.no_grad():
-            orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
-            orig = VV_(orig[None])
-            result = model(orig).detach().cpu().numpy()
-            result = self._denorm(result)
-
-        if training:
-            model.train()
-        return result[0]
-
-    def _denorm(self, image: ndarray):
-        if len(image.shape)==3: arr = arr[None]
-        return self.denorm(np.rollaxis(image,1,4))
-
-    def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
-        for tfm in tfms:
-            orig,_=tfm(orig, False)
-        _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
-        val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not isinstance(tfm, NoCrop)]
-        orig = val_tfms(orig)
-        return orig
-
-    def _get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
-        im = open_image(str(path))
-        sz = self.default_sz if sz is None else sz
-        im = scale_min(im, sz)
-        im = self._transform(im, tfms, model, sz)
-        return im
+        misc.imsave(result_path, np.clip(result,0,1))
+
+    def _get_transformed_image_ndarray(self, path:Path, render_factor:int=None):
+        orig = open_image(str(path))
+        result = orig
+        render_factor = self.render_factor if render_factor is None else render_factor
+
+        for filt in self.filters:
+            result = filt.filter(result, render_factor=render_factor)
+
+        return result
 
     def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None: 
@@ -171,12 +134,11 @@ class ImageGenVisualizer():
     def __init__(self):
         self.model_vis = ModelImageVisualizer()
 
-    def output_image_gen_visuals(self, md:ImageData, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, jupyter:bool=False):
-        self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=True)
-        self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=False)
+    def output_image_gen_visuals(self, md:ImageData, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter):
+        self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=True)
+        self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=False)
 
-    def _output_visuals(self, ds:FilesDataset, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, 
-            validation:bool, jupyter:bool=False):
+    def _output_visuals(self, ds:FilesDataset, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter, validation:bool):
         #TODO:  Parameterize these
         start_idx=0
         count = 8
@@ -184,8 +146,6 @@ class ImageGenVisualizer():
         idxs = list(range(start_idx,end_index))
         image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
         self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, validation=validation)
-        if jupyter:
-            self._show_images_in_jupyter(image_sets, validation=validation)
     
     def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, validation:bool):
         orig_images = []
@@ -204,15 +164,6 @@ class ImageGenVisualizer():
         tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
 
 
-    def _show_images_in_jupyter(self, image_sets:[ModelImageSet], validation:bool):
-        #TODO:  Parameterize these
-        figsize=(20,20)
-        max_columns=4
-        immediate_display=True
-        self.model_vis.plot_images_from_image_sets(image_sets, figsize=figsize, max_columns=max_columns, 
-            immediate_display=immediate_display, validation=validation)
-
-
 class GANTrainerStatsVisualizer():
     def __init__(self):
         return

Některé soubory nejsou zobrazeny, neboť je v těchto rozdílových datech změněno mnoho souborů