@@ -6,26 +6,28 @@ from matplotlib.figure import Figure
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
-from fastai.transforms import CropType, NoCrop, Denormalize
+from fastai.transforms import CropType, NoCrop, Denormalize, Scale
+from fasterai.transforms import BlackAndWhiteTransform
 from .training import GenResult, CriticResult, GANTrainer
 from .images import ModelImageSet, EasyTensorImage
+from .generators import GeneratorModule
+from .filters import Filter, Colorizer
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
 import torchvision.utils as vutils
 import statistics
-from PIL import Image
-
+from PIL import Image
 
 class ModelImageVisualizer():
-    def __init__(self, default_sz:int=500, results_dir:str=None):
-        self.default_sz=default_sz
-        self.denorm = Denormalize(*inception_stats)
+    def __init__(self, filters:[Filter]=[], render_factor:int=18, results_dir:str=None):
+        self.filters = filters
+        self.render_factor=render_factor
         self.results_dir=None if results_dir is None else Path(results_dir)
 
-    def plot_transformed_image(self, path:str, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, tfms:[Transform]=[])->ndarray:
+    def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->ndarray:
         path = Path(path)
-        result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        result = self._get_transformed_image_ndarray(path, render_factor)
         orig = open_image(str(path))
         fig,axes = plt.subplots(1, 2, figsize=figsize)
         self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
@@ -34,63 +36,24 @@ class ModelImageVisualizer():
         if self.results_dir is not None:
             self._save_result_image(path, result)
 
-    def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
+    def get_transformed_image_as_pil(self, path:str, render_factor:int=None)->Image:
         path = Path(path)
-        array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
+        array = self._get_transformed_image_ndarray(path, render_factor)
        return misc.toimage(array)
 
     def _save_result_image(self, source_path:Path, result:ndarray):
         result_path = self.results_dir/source_path.name
-        misc.imsave(result_path, result)
-
-    def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
-                                    max_columns:int=6, immediate_display:bool=True):
-        num_sets = len(image_sets)
-        num_images = num_sets * 2
-        rows, columns = self._get_num_rows_columns(num_images, max_columns)
-
-        fig, axes = plt.subplots(rows, columns, figsize=figsize)
-        title = 'Validation' if validation else 'Training'
-        fig.suptitle(title, fontsize=16)
-
-        for i, image_set in enumerate(image_sets):
-            self._plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
-            self._plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])
-
-        if immediate_display:
-            display(fig)
-
-    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
-        training = model.training
-        model.eval()
-        with torch.no_grad():
-            orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
-            orig = VV_(orig[None])
-            result = model(orig).detach().cpu().numpy()
-            result = self._denorm(result)
-
-        if training:
-            model.train()
-        return result[0]
-
-    def _denorm(self, image: ndarray):
-        if len(image.shape)==3: arr = arr[None]
-        return self.denorm(np.rollaxis(image,1,4))
-
-    def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
-        for tfm in tfms:
-            orig,_=tfm(orig, False)
-        _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
-        val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not isinstance(tfm, NoCrop)]
-        orig = val_tfms(orig)
-        return orig
-
-    def _get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
-        im = open_image(str(path))
-        sz = self.default_sz if sz is None else sz
-        im = scale_min(im, sz)
-        im = self._transform(im, tfms, model, sz)
-        return im
+        misc.imsave(result_path, np.clip(result,0,1))
+
+    def _get_transformed_image_ndarray(self, path:Path, render_factor:int=None):
+        orig = open_image(str(path))
+        result = orig
+        render_factor = self.render_factor if render_factor is None else render_factor
+
+        for filt in self.filters:
+            result = filt.filter(result, render_factor=render_factor)
+
+        return result
 
     def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
         if axes is None:
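Note on the filter-based refactor above: the new _get_transformed_image_ndarray only requires that each entry in self.filters expose a filter(image, render_factor=...) method. The Filter and Colorizer classes it imports live in fasterai/filters.py, which is not part of this diff, so the sketch below is only an illustration of the assumed interface, not the actual fasterai implementation.

    # Hypothetical sketch of the interface ModelImageVisualizer relies on.
    # The real Filter/Colorizer implementations are in fasterai/filters.py (not shown in this diff).
    from abc import ABC, abstractmethod
    from numpy import ndarray

    class Filter(ABC):
        @abstractmethod
        def filter(self, image:ndarray, render_factor:int)->ndarray:
            # Return a transformed copy of image; render_factor scales the working resolution.
            pass

    class IdentityFilter(Filter):
        # Trivial stand-in: passes the image through unchanged, useful for testing the filter chain.
        def filter(self, image:ndarray, render_factor:int)->ndarray:
            return image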
@@ -171,12 +134,11 @@ class ImageGenVisualizer():
     def __init__(self):
         self.model_vis = ModelImageVisualizer()
 
-    def output_image_gen_visuals(self, md:ImageData, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, jupyter:bool=False):
-        self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=True)
-        self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=False)
+    def output_image_gen_visuals(self, md:ImageData, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter):
+        self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=True)
+        self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=False)
 
-    def _output_visuals(self, ds:FilesDataset, model:nn.Module, iter_count:int, tbwriter:SummaryWriter,
-                        validation:bool, jupyter:bool=False):
+    def _output_visuals(self, ds:FilesDataset, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter, validation:bool):
         #TODO: Parameterize these
         start_idx=0
         count = 8
@@ -184,8 +146,6 @@ class ImageGenVisualizer():
         idxs = list(range(start_idx,end_index))
         image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
         self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, validation=validation)
-        if jupyter:
-            self._show_images_in_jupyter(image_sets, validation=validation)
 
     def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, validation:bool):
         orig_images = []
@@ -204,15 +164,6 @@ class ImageGenVisualizer():
         tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
 
-    def _show_images_in_jupyter(self, image_sets:[ModelImageSet], validation:bool):
-        #TODO: Parameterize these
-        figsize=(20,20)
-        max_columns=4
-        immediate_display=True
-        self.model_vis.plot_images_from_image_sets(image_sets, figsize=figsize, max_columns=max_columns,
-                                                   immediate_display=immediate_display, validation=validation)
-
-
 class GANTrainerStatsVisualizer():
     def __init__(self):
         return
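End-to-end, the reworked ModelImageVisualizer is now driven by a list of filters instead of a raw nn.Module plus transforms. A minimal usage sketch follows, assuming this file is fasterai/visualize.py and that a Colorizer filter wraps a trained generator; Colorizer's constructor is defined in fasterai/filters.py and is not shown in this diff, so its construction below is a placeholder.

    # Illustrative only; the Colorizer construction is an assumption, not the real signature.
    from fasterai.filters import Colorizer
    from fasterai.visualize import ModelImageVisualizer

    colorizer = Colorizer()  # placeholder; the real constructor likely takes a trained generator or weights
    vis = ModelImageVisualizer(filters=[colorizer], render_factor=18, results_dir='result_images')

    # Plots the original next to the filtered result; because results_dir is set,
    # the rendered image is also written to result_images/<source filename>.
    vis.plot_transformed_image('test_images/some_photo.jpg', render_factor=24)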