فهرست منبع

Getting rid of dependence on imagenet images for visualizations

They were used just to get a ds object which has the denorm function.  I just made my own denorm instead.  Git 'r done.
Jason Antic 6 سال پیش
والد
کامیت
addc583174
4 فایل تغییر یافته به همراه 887 افزوده شده و 163 حذف شده
  1. 758 33
      ColorizeVisualization.ipynb
  2. 111 114
      ComboVisualization.ipynb
  3. 6 9
      DeFadeVisualization.ipynb
  4. 12 7
      fasterai/visualize.py

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 758 - 33
ColorizeVisualization.ipynb


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 111 - 114
ComboVisualization.ipynb


+ 6 - 9
DeFadeVisualization.ipynb

@@ -47,7 +47,6 @@
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
     "gpath = IMAGENET.parent/('defade_rc_gen_256.h5')\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
@@ -70,9 +69,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "x_tfms = []\n",
-    "data_loader = ImageGenDataLoader(sz=256, bs=8, path=IMAGENET_SMALL, random_seed=42, keep_pct=1.0, x_tfms=x_tfms)\n",
-    "md = data_loader.get_model_data()"
+    "x_tfms = []"
    ]
   },
   {
@@ -90,7 +87,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedOvermiller.PNG\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/FadedOvermiller.PNG\", netG, tfms=x_tfms)"
    ]
   },
   {
@@ -99,7 +96,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedSphynx.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedSphynx.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {
@@ -108,7 +105,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedRacket.PNG\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/FadedRacket.PNG\", netG, tfms=x_tfms)"
    ]
   },
   {
@@ -117,7 +114,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedDutchBabies.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedDutchBabies.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {
@@ -126,7 +123,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FadedDelores.PNG\", netG, md.val_ds, tfms=x_tfms, sz=500)"
+    "vis.plot_transformed_image(\"test_images/FadedDelores.PNG\", netG, tfms=x_tfms, sz=500)"
    ]
   },
   {

+ 12 - 7
fasterai/visualize.py

@@ -4,7 +4,7 @@ from fastai.core import *
 from matplotlib.axes import Axes
 from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
 from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
-from fastai.transforms import CropType, NoCrop
+from fastai.transforms import CropType, NoCrop, Denormalize
 from fasterai.training import GenResult, CriticResult, GANTrainer
 from fasterai.images import ModelImageSet, EasyTensorImage
 from IPython.display import display
@@ -16,10 +16,11 @@ import statistics
 class ModelImageVisualizer():
     def __init__(self, default_sz:int=500):
         self.default_sz=default_sz 
+        self.denorm = Denormalize(*inception_stats) 
 
-    def plot_transformed_image(self, path:Path, model:nn.Module, ds:FilesDataset, figsize:(int,int)=(20,20), sz:int=None, 
+    def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None, 
             tfms:[Transform]=[], compare:bool=True):
-        result = self.get_transformed_image_ndarray(path, model,ds, sz, tfms=tfms)
+        result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
         if compare: 
             orig = open_image(str(path))
             fig,axes = plt.subplots(1, 2, figsize=figsize)
@@ -28,17 +29,21 @@ class ModelImageVisualizer():
         else:
             self.plot_image_from_ndarray(result, figsize=figsize)
 
-    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         training = model.training 
         model.eval()
-        orig = self.get_model_ready_image_ndarray(path, model, ds, sz, tfms)
+        orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
         orig = VV(orig[None])
         result = model(orig).detach().cpu().numpy()
-        result = ds.denorm(result)
+        result = self._denorm(result)
         if training:
             model.train()
         return result[0]
 
+    def _denorm(self, image: ndarray):
+        if len(image.shape)==3: image = image[None]
+        return self.denorm(np.rollaxis(image,1,4))
+
     def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
         for tfm in tfms:
             orig,_=tfm(orig, False)
@@ -47,7 +52,7 @@ class ModelImageVisualizer():
         orig = val_tfms(orig)
         return orig
 
-    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, ds:FilesDataset, sz:int=None, tfms:[Transform]=[]):
+    def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
         im = open_image(str(path))
         sz = self.default_sz if sz is None else sz
         im = scale_min(im, sz)

برخی فایل ها در این مقایسه diff نمایش داده نمی شوند زیرا تعداد فایل ها بسیار زیاد است