Kaynağa Gözat

Updating visualization code to use "dummy data"

Jason Antic 6 yıl önce
ebeveyn
işleme
5e7ad7ba45
2 değiştirilmiş dosya ile 29 ekleme ve 25 silme
  1. 22 24
      ColorizeVisualization.ipynb
  2. 7 1
      fasterai/dataset.py

+ 22 - 24
ColorizeVisualization.ipynb

@@ -10,17 +10,6 @@
     "os.environ['CUDA_VISIBLE_DEVICES']='3' "
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%matplotlib inline\n",
-    "%reload_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -50,13 +39,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "#Dummy data path- shouldn't pull any images.\n",
-    "path = Path('./')\n",
-    "#The higher the render_factor, the more GPU memory will be used and generally images will look better.  \n",
-    "#11GB can take a factor of 42 max.  Performance generally gracefully degrades with lower factors, \n",
-    "#though you may also find that certain images will actually render better at lower numbers.  \n",
-    "#This tends to be the case with the oldest photos.\n",
-    "render_factor=16"
+    "#Adjust this if image doesn't look quite right (max 64 on 11GB GPU).  The default here works for most photos.  \n",
+    "#It literally just is a number multiplied by 16 to get the square render resolution.  \n",
+    "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
+    "#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  \n",
+    "render_factor=21"
    ]
   },
   {
@@ -67,7 +54,17 @@
    "source": [
     "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
     "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
-    "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=1.5)"
+    "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#TODO: Replace this with loading learner via exported learner.\n",
+    "data = get_dummy_databunch()"
    ]
   },
   {
@@ -76,10 +73,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "data = get_colorize_data(sz=128, bs=32, crappy_path=path, good_path=path, keep_pct=0.01)\n",
     "learn = colorize_gen_learner_exp(data=data)\n",
+    "#switch to read models from proper place\n",
     "learn.path = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
-    "learn.load('colorize3_gen_96')\n",
+    "learn.load('ColorizeNew3_gen224')\n",
+    "#learn.load('colorize1b_gen_224')\n",
     "learn.model.eval()\n",
     "filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n",
     "vis = ModelImageVisualizer(filtr, results_dir='result_images')"
@@ -2762,7 +2760,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images//ParisLate1800s.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/PaddingtonStationLondon1907.jpg\", render_factor=55)"
    ]
   },
   {
@@ -2771,8 +2769,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "for factor in range(20,66):\n",
-    "    vis.plot_transformed_image(\"test_images/1890sMedStudents.png\", render_factor=factor)"
+    "for factor in range(10,64):\n",
+    "    vis.plot_transformed_image(\"test_images/PaddingtonStationLondon1907.jpg\", render_factor=factor)"
    ]
   },
   {

+ 7 - 1
fasterai/dataset.py

@@ -14,8 +14,14 @@ def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_s
     data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
         #TODO:  Revisit transforms used here....
         .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25), size=sz, tfm_y=True)
-        .databunch(bs=bs, num_workers=num_workers)
+        .databunch(bs=bs, num_workers=num_workers, no_check=True)
         .normalize(imagenet_stats, do_y=True))
 
     data.c = 3
     return data
+
+
+
+def get_dummy_databunch()->ImageDataBunch:
+    path = Path('./dummy/')
+    return get_colorize_data(sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001)