Переглянути джерело

Adding result image save and get result as pil image fucntionality for model image visualizer

Jason Antic 6 роки тому
батько
коміт
79803a9366
1 змінених файлів з 44 додано та 37 видалено
  1. 44 37
      fasterai/visualize.py

+ 44 - 37
fasterai/visualize.py

@@ -2,6 +2,8 @@ from numpy import ndarray
 from fastai.torch_imports import *
 from fastai.core import *
 from matplotlib.axes import Axes
+from matplotlib.figure import Figure
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
 from fastai.transforms import CropType, NoCrop, Denormalize
@@ -9,30 +11,59 @@ from .training import GenResult, CriticResult, GANTrainer
 from .images import ModelImageSet, EasyTensorImage
 from IPython.display import display
 from tensorboardX import SummaryWriter
+from scipy import misc
 import torchvision.utils as vutils
 import statistics
+from PIL import Image 
 
 
 class ModelImageVisualizer():
-    def __init__(self, default_sz:int=500):
+    def __init__(self, default_sz:int=500, results_dir:str=None):
         self.default_sz=default_sz 
         self.denorm = Denormalize(*inception_stats) 
+        self.results_dir=None if results_dir is None else Path(results_dir)
 
-    def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, 
-            tfms:[Transform]=[], compare:bool=True):
+    def plot_transformed_image(self, path:str, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, tfms:[Transform]=[])->ndarray:
+        path = Path(path)
         result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
-        if compare: 
-            orig = open_image(str(path))
-            fig,axes = plt.subplots(1, 2, figsize=figsize)
-            self.plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
-            self.plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
-        else:
-            self.plot_image_from_ndarray(result, figsize=figsize)
+        orig = open_image(str(path))
+        fig,axes = plt.subplots(1, 2, figsize=figsize)
+        self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
+        self._plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
+
+        if self.results_dir is not None:
+            self._save_result_image(path, result)
+
+    def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
+        path = Path(path)
+        array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        return misc.toimage(array)
+
+    def _save_result_image(self, source_path:Path, result:ndarray):
+        result_path = self.results_dir/source_path.name
+        misc.imsave(result_path, result)
+
+    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
+            max_columns:int=6, immediate_display:bool=True):
+        num_sets = len(image_sets)
+        num_images = num_sets * 2
+        rows, columns = self._get_num_rows_columns(num_images, max_columns)
+
+        fig, axes = plt.subplots(rows, columns, figsize=figsize)
+        title = 'Validation' if validation else 'Training'
+        fig.suptitle(title, fontsize=16)
+
+        for i, image_set in enumerate(image_sets):
+            self._plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
+            self._plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
+
+        if immediate_display:
+            display(fig)
 
     def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         training = model.training 
         model.eval()
-        orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
+        orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
         orig = VV(orig[None])
         result = model(orig).detach().cpu().numpy()
         result = self._denorm(result)
@@ -52,44 +83,20 @@ class ModelImageVisualizer():
         orig = val_tfms(orig)
         return orig
 
-    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
+    def _get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         im = open_image(str(path))
         sz = self.default_sz if sz is None else sz
         im = scale_min(im, sz)
         im = self._transform(im, tfms, model, sz)
         return im
 
-    def plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
+    def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None: 
             _,axes = plt.subplots(figsize=figsize)
         clipped_image =np.clip(image,0,1)
         axes.imshow(clipped_image)
         axes.axis('off')
 
-
-    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
-            max_columns:int=6, immediate_display:bool=True):
-        num_sets = len(image_sets)
-        num_images = num_sets * 2
-        rows, columns = self._get_num_rows_columns(num_images, max_columns)
-
-        fig, axes = plt.subplots(rows, columns, figsize=figsize)
-        title = 'Validation' if validation else 'Training'
-        fig.suptitle(title, fontsize=16)
-
-        for i, image_set in enumerate(image_sets):
-            self.plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
-            self.plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
-
-        if immediate_display:
-            display(fig)
-
-
-    def plot_image_outputs_from_model(self, ds:FilesDataset, model:nn.Module, idxs:[int], figsize:(int,int)=(20,20), max_columns:int=6, 
-            immediate_display:bool=True):
-        image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
-        self.plot_images_from_image_sets(image_sets=image_sets, figsize=figsize, max_columns=max_columns, immediate_display=immediate_display)
-
     def _get_num_rows_columns(self, num_images:int, max_columns:int):
         columns = min(num_images, max_columns)
         rows = num_images//columns