|
@@ -2,6 +2,8 @@ from numpy import ndarray
|
|
|
from fastai.torch_imports import *
|
|
|
from fastai.core import *
|
|
|
from matplotlib.axes import Axes
|
|
|
+from matplotlib.figure import Figure
|
|
|
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
|
from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
|
|
|
from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
|
|
|
from fastai.transforms import CropType, NoCrop, Denormalize
|
|
@@ -9,30 +11,59 @@ from .training import GenResult, CriticResult, GANTrainer
|
|
|
from .images import ModelImageSet, EasyTensorImage
|
|
|
from IPython.display import display
|
|
|
from tensorboardX import SummaryWriter
|
|
|
+from scipy import misc
|
|
|
import torchvision.utils as vutils
|
|
|
import statistics
|
|
|
+from PIL import Image
|
|
|
|
|
|
|
|
|
class ModelImageVisualizer():
|
|
|
- def __init__(self, default_sz:int=500):
|
|
|
+ def __init__(self, default_sz:int=500, results_dir:str=None):
|
|
|
self.default_sz=default_sz
|
|
|
self.denorm = Denormalize(*inception_stats)
|
|
|
+ self.results_dir=None if results_dir is None else Path(results_dir)
|
|
|
|
|
|
- def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None,
|
|
|
- tfms:[Transform]=[], compare:bool=True):
|
|
|
+ def plot_transformed_image(self, path:str, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, tfms:[Transform]=[])->ndarray:
|
|
|
+ path = Path(path)
|
|
|
result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
|
|
|
- if compare:
|
|
|
- orig = open_image(str(path))
|
|
|
- fig,axes = plt.subplots(1, 2, figsize=figsize)
|
|
|
- self.plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
|
|
|
- self.plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
|
|
|
- else:
|
|
|
- self.plot_image_from_ndarray(result, figsize=figsize)
|
|
|
+ orig = open_image(str(path))
|
|
|
+ fig,axes = plt.subplots(1, 2, figsize=figsize)
|
|
|
+ self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
|
|
|
+ self._plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
|
|
|
+
|
|
|
+ if self.results_dir is not None:
|
|
|
+ self._save_result_image(path, result)
|
|
|
+
|
|
|
+ def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
|
|
|
+ path = Path(path)
|
|
|
+ array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
|
|
|
+ return misc.toimage(array)
|
|
|
+
|
|
|
+ def _save_result_image(self, source_path:Path, result:ndarray):
|
|
|
+ result_path = self.results_dir/source_path.name
|
|
|
+ misc.imsave(result_path, result)
|
|
|
+
|
|
|
+ def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
|
|
|
+ max_columns:int=6, immediate_display:bool=True):
|
|
|
+ num_sets = len(image_sets)
|
|
|
+ num_images = num_sets * 2
|
|
|
+ rows, columns = self._get_num_rows_columns(num_images, max_columns)
|
|
|
+
|
|
|
+ fig, axes = plt.subplots(rows, columns, figsize=figsize)
|
|
|
+ title = 'Validation' if validation else 'Training'
|
|
|
+ fig.suptitle(title, fontsize=16)
|
|
|
+
|
|
|
+ for i, image_set in enumerate(image_sets):
|
|
|
+ self._plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
|
|
|
+ self._plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
|
|
|
+
|
|
|
+ if immediate_display:
|
|
|
+ display(fig)
|
|
|
|
|
|
def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
training = model.training
|
|
|
model.eval()
|
|
|
- orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
|
|
|
+ orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
|
|
|
orig = VV(orig[None])
|
|
|
result = model(orig).detach().cpu().numpy()
|
|
|
result = self._denorm(result)
|
|
@@ -52,44 +83,20 @@ class ModelImageVisualizer():
|
|
|
orig = val_tfms(orig)
|
|
|
return orig
|
|
|
|
|
|
- def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
+ def _get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
im = open_image(str(path))
|
|
|
sz = self.default_sz if sz is None else sz
|
|
|
im = scale_min(im, sz)
|
|
|
im = self._transform(im, tfms, model, sz)
|
|
|
return im
|
|
|
|
|
|
- def plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
|
|
|
+ def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
|
|
|
if axes is None:
|
|
|
_,axes = plt.subplots(figsize=figsize)
|
|
|
clipped_image =np.clip(image,0,1)
|
|
|
axes.imshow(clipped_image)
|
|
|
axes.axis('off')
|
|
|
|
|
|
-
|
|
|
- def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
|
|
|
- max_columns:int=6, immediate_display:bool=True):
|
|
|
- num_sets = len(image_sets)
|
|
|
- num_images = num_sets * 2
|
|
|
- rows, columns = self._get_num_rows_columns(num_images, max_columns)
|
|
|
-
|
|
|
- fig, axes = plt.subplots(rows, columns, figsize=figsize)
|
|
|
- title = 'Validation' if validation else 'Training'
|
|
|
- fig.suptitle(title, fontsize=16)
|
|
|
-
|
|
|
- for i, image_set in enumerate(image_sets):
|
|
|
- self.plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
|
|
|
- self.plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
|
|
|
-
|
|
|
- if immediate_display:
|
|
|
- display(fig)
|
|
|
-
|
|
|
-
|
|
|
- def plot_image_outputs_from_model(self, ds:FilesDataset, model:nn.Module, idxs:[int], figsize:(int,int)=(20,20), max_columns:int=6,
|
|
|
- immediate_display:bool=True):
|
|
|
- image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
|
|
|
- self.plot_images_from_image_sets(image_sets=image_sets, figsize=figsize, max_columns=max_columns, immediate_display=immediate_display)
|
|
|
-
|
|
|
def _get_num_rows_columns(self, num_images:int, max_columns:int):
|
|
|
columns = min(num_images, max_columns)
|
|
|
rows = num_images//columns
|