浏览代码

Getting rid of dependence on imagenet images for visualizations

They were used just to get a ds object which has the denorm function. I just made my own denorm instead. Git 'r done.
Jason Antic 6 年之前
父节点
当前提交
addc583174
共有 4 个文件被更改,包括 887 次插入和 163 次删除
  1. 758 33
      ColorizeVisualization.ipynb
  2. 111 114
      ComboVisualization.ipynb
  3. 6 9
      DeFadeVisualization.ipynb
  4. 12 7
      fasterai/visualize.py

文件差异内容过多而无法显示
+ 758 - 33
ColorizeVisualization.ipynb


文件差异内容过多而无法显示
+ 111 - 114
ComboVisualization.ipynb


+ 6 - 9
DeFadeVisualization.ipynb

@@ -47,7 +47,6 @@
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
     "gpath = IMAGENET.parent/('defade_rc_gen_256.h5')\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
@@ -70,9 +69,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "x_tfms = []\n",
-    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, keep_pct=1.0, x_tfms=x_tfms)\n",
-    "md = data_loader.get_model_data()"
+    "x_tfms = []"
    ]
   },
   {
@@ -90,7 +87,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedOvermiller.PNG\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/FadedOvermiller.PNG\", netG, tfms=x_tfms)"
    ]
   },
   {
@@ -99,7 +96,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedSphynx.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedSphynx.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {
@@ -108,7 +105,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedRacket.PNG\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/FadedRacket.PNG\", netG, tfms=x_tfms)"
    ]
   },
   {
@@ -117,7 +114,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedDutchBabies.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedDutchBabies.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {
@@ -126,7 +123,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedDelores.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedDelores.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {

+ 12 - 7
fasterai/visualize.py

@@ -4,7 +4,7 @@ from fastai.core import *
 from matplotlib.axes import Axes
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
-from fastai.transforms import CropType, NoCrop
+from fastai.transforms import CropType, NoCrop, Denormalize
 from fasterai.training import GenResult, CriticResult, GANTrainer
 from fasterai.images import ModelImageSet, EasyTensorImage
 from IPython.display import display
@@ -16,10 +16,11 @@ import statistics
 class ModelImageVisualizer():
     def __init__(self, default_sz:int=500):
         self.default_sz=default_sz 
+        self.denorm = Denormalize(*inception_stats) 
 
-    def plot_transformed_image(self, path:Path, model:nn.Module, ds:FilesDataset, figsize:(int,int)=(20,20), sz:int=None, 
+    def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, 
             tfms:[Transform]=[], compare:bool=True):
-        result = self.get_transformed_image_ndarray(path, model,ds, sz, tfms=tfms)
+        result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
         if compare: 
             orig = open_image(str(path))
             fig,axes = plt.subplots(1, 2, figsize=figsize)
@@ -28,17 +29,21 @@ class ModelImageVisualizer():
         else:
             self.plot_image_from_ndarray(result, figsize=figsize)
 
-    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         training = model.training 
         model.eval()
-        orig = self.get_model_ready_image_ndarray(path, model, ds, sz, tfms)
+        orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
         orig = VV(orig[None])
         result = model(orig).detach().cpu().numpy()
-        result = ds.denorm(result)
+        result = self._denorm(result)
         if training:
             model.train()
         return result[0]
 
+    def _denorm(self, image: ndarray):
+        if len(image.shape)==3: arr = arr[None]
+        return self.denorm(np.rollaxis(image,1,4))
+
     def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
         for tfm in tfms:
             orig,_=tfm(orig, False)
@@ -47,7 +52,7 @@ class ModelImageVisualizer():
         orig = val_tfms(orig)
         return orig
 
-    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         im = open_image(str(path))
         sz = self.default_sz if sz is None else sz
         im = scale_min(im, sz)

部分文件因为文件数量过多而无法显示