|
@@ -4,7 +4,7 @@ from fastai.core import *
|
|
|
from matplotlib.axes import Axes
|
|
|
from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
|
|
|
from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
|
|
|
-from fastai.transforms import CropType, NoCrop
|
|
|
+from fastai.transforms import CropType, NoCrop, Denormalize
|
|
|
from fasterai.training import GenResult, CriticResult, GANTrainer
|
|
|
from fasterai.images import ModelImageSet, EasyTensorImage
|
|
|
from IPython.display import display
|
|
@@ -16,10 +16,11 @@ import statistics
|
|
|
class ModelImageVisualizer():
|
|
|
def __init__(self, default_sz:int=500):
|
|
|
self.default_sz=default_sz
|
|
|
+ self.denorm = Denormalize(*inception_stats)
|
|
|
|
|
|
- def plot_transformed_image(self, path:Path, model:nn.Module, ds:FilesDataset, figsize:(int,int)=(20,20), sz:int=None,
|
|
|
+ def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None,
|
|
|
tfms:[Transform]=[], compare:bool=True):
|
|
|
- result = self.get_transformed_image_ndarray(path, model,ds, sz, tfms=tfms)
|
|
|
+ result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
|
|
|
if compare:
|
|
|
orig = open_image(str(path))
|
|
|
fig,axes = plt.subplots(1, 2, figsize=figsize)
|
|
@@ -28,17 +29,21 @@ class ModelImageVisualizer():
|
|
|
else:
|
|
|
self.plot_image_from_ndarray(result, figsize=figsize)
|
|
|
|
|
|
- def get_transformed_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
|
|
|
+ def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
training = model.training
|
|
|
model.eval()
|
|
|
- orig = self.get_model_ready_image_ndarray(path, model, ds, sz, tfms)
|
|
|
+ orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
|
|
|
orig = VV(orig[None])
|
|
|
result = model(orig).detach().cpu().numpy()
|
|
|
- result = ds.denorm(result)
|
|
|
+ result = self._denorm(result)
|
|
|
if training:
|
|
|
model.train()
|
|
|
return result[0]
|
|
|
|
|
|
+ def _denorm(self, image: ndarray):
|
|
|
+ if len(image.shape)==3: arr = arr[None]
|
|
|
+ return self.denorm(np.rollaxis(image,1,4))
|
|
|
+
|
|
|
def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
|
|
|
for tfm in tfms:
|
|
|
orig,_=tfm(orig, False)
|
|
@@ -47,7 +52,7 @@ class ModelImageVisualizer():
|
|
|
orig = val_tfms(orig)
|
|
|
return orig
|
|
|
|
|
|
- def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
|
|
|
+ def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
im = open_image(str(path))
|
|
|
sz = self.default_sz if sz is None else sz
|
|
|
im = scale_min(im, sz)
|