瀏覽代碼

Adding result image save and get result as pil image fucntionality for model image visualizer

Jason Antic 6 年之前
父節點
當前提交
79803a9366
共有 1 個文件被更改,包括 44 次插入37 次删除
  1. 44 37
      fasterai/visualize.py

+ 44 - 37
fasterai/visualize.py

@@ -2,6 +2,8 @@ from numpy import ndarray
 from fastai.torch_imports import *
 from fastai.torch_imports import *
 from fastai.core import *
 from fastai.core import *
 from matplotlib.axes import Axes
 from matplotlib.axes import Axes
+from matplotlib.figure import Figure
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
 from fastai.transforms import CropType, NoCrop, Denormalize
 from fastai.transforms import CropType, NoCrop, Denormalize
@@ -9,30 +11,59 @@ from .training import GenResult, CriticResult, GANTrainer
 from .images import ModelImageSet, EasyTensorImage
 from .images import ModelImageSet, EasyTensorImage
 from IPython.display import display
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from tensorboardX import SummaryWriter
+from scipy import misc
 import torchvision.utils as vutils
 import torchvision.utils as vutils
 import statistics
 import statistics
+from PIL import Image 
 
 
 
 
 class ModelImageVisualizer():
 class ModelImageVisualizer():
-    def __init__(self, default_sz:int=500):
+    def __init__(self, default_sz:int=500, results_dir:str=None):
         self.default_sz=default_sz 
         self.default_sz=default_sz 
         self.denorm = Denormalize(*inception_stats) 
         self.denorm = Denormalize(*inception_stats) 
+        self.results_dir=None if results_dir is None else Path(results_dir)
 
 
-    def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, 
-            tfms:[Transform]=[], compare:bool=True):
+    def plot_transformed_image(self, path:str, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, tfms:[Transform]=[])->ndarray:
+        path = Path(path)
         result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
         result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
-        if compare: 
-            orig = open_image(str(path))
-            fig,axes = plt.subplots(1, 2, figsize=figsize)
-            self.plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
-            self.plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
-        else:
-            self.plot_image_from_ndarray(result, figsize=figsize)
+        orig = open_image(str(path))
+        fig,axes = plt.subplots(1, 2, figsize=figsize)
+        self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
+        self._plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
+
+        if self.results_dir is not None:
+            self._save_result_image(path, result)
+
+    def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
+        path = Path(path)
+        array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        return misc.toimage(array)
+
+    def _save_result_image(self, source_path:Path, result:ndarray):
+        result_path = self.results_dir/source_path.name
+        misc.imsave(result_path, result)
+
+    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
+            max_columns:int=6, immediate_display:bool=True):
+        num_sets = len(image_sets)
+        num_images = num_sets * 2
+        rows, columns = self._get_num_rows_columns(num_images, max_columns)
+
+        fig, axes = plt.subplots(rows, columns, figsize=figsize)
+        title = 'Validation' if validation else 'Training'
+        fig.suptitle(title, fontsize=16)
+
+        for i, image_set in enumerate(image_sets):
+            self._plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
+            self._plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
+
+        if immediate_display:
+            display(fig)
 
 
     def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
     def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         training = model.training 
         training = model.training 
         model.eval()
         model.eval()
-        orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
+        orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
         orig = VV(orig[None])
         orig = VV(orig[None])
         result = model(orig).detach().cpu().numpy()
         result = model(orig).detach().cpu().numpy()
         result = self._denorm(result)
         result = self._denorm(result)
@@ -52,44 +83,20 @@ class ModelImageVisualizer():
         orig = val_tfms(orig)
         orig = val_tfms(orig)
         return orig
         return orig
 
 
-    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
+    def _get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         im = open_image(str(path))
         im = open_image(str(path))
         sz = self.default_sz if sz is None else sz
         sz = self.default_sz if sz is None else sz
         im = scale_min(im, sz)
         im = scale_min(im, sz)
         im = self._transform(im, tfms, model, sz)
         im = self._transform(im, tfms, model, sz)
         return im
         return im
 
 
-    def plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
+    def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None: 
         if axes is None: 
             _,axes = plt.subplots(figsize=figsize)
             _,axes = plt.subplots(figsize=figsize)
         clipped_image =np.clip(image,0,1)
         clipped_image =np.clip(image,0,1)
         axes.imshow(clipped_image)
         axes.imshow(clipped_image)
         axes.axis('off')
         axes.axis('off')
 
 
-
-    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20), 
-            max_columns:int=6, immediate_display:bool=True):
-        num_sets = len(image_sets)
-        num_images = num_sets * 2
-        rows, columns = self._get_num_rows_columns(num_images, max_columns)
-
-        fig, axes = plt.subplots(rows, columns, figsize=figsize)
-        title = 'Validation' if validation else 'Training'
-        fig.suptitle(title, fontsize=16)
-
-        for i, image_set in enumerate(image_sets):
-            self.plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
-            self.plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
-
-        if immediate_display:
-            display(fig)
-
-
-    def plot_image_outputs_from_model(self, ds:FilesDataset, model:nn.Module, idxs:[int], figsize:(int,int)=(20,20), max_columns:int=6, 
-            immediate_display:bool=True):
-        image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
-        self.plot_images_from_image_sets(image_sets=image_sets, figsize=figsize, max_columns=max_columns, immediate_display=immediate_display)
-
     def _get_num_rows_columns(self, num_images:int, max_columns:int):
     def _get_num_rows_columns(self, num_images:int, max_columns:int):
         columns = min(num_images, max_columns)
         columns = min(num_images, max_columns)
         rows = num_images//columns
         rows = num_images//columns