|
@@ -25,7 +25,6 @@ class ModelImageVisualizer():
|
|
|
self.results_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _clean_mem(self):
|
|
|
- return
|
|
|
torch.cuda.empty_cache()
|
|
|
#gc.collect()
|
|
|
|
|
@@ -38,20 +37,32 @@ class ModelImageVisualizer():
|
|
|
return img
|
|
|
|
|
|
def plot_transformed_image_from_url(self, url:str, path:str='test_images/image.png', figsize:(int,int)=(20,20),
|
|
|
- render_factor:int=None, display_render_factor:bool=False)->Path:
|
|
|
+ render_factor:int=None, display_render_factor:bool=False, compare:bool=False)->Path:
|
|
|
img = self._get_image_from_url(url)
|
|
|
img.save(path)
|
|
|
return self.plot_transformed_image(path=path, figsize=figsize, render_factor=render_factor,
|
|
|
- display_render_factor=display_render_factor)
|
|
|
+ display_render_factor=display_render_factor, compare=compare)
|
|
|
|
|
|
- def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None, display_render_factor:bool=False)->Path:
|
|
|
+ def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None,
|
|
|
+ display_render_factor:bool=False, compare:bool=False)->Path:
|
|
|
path = Path(path)
|
|
|
result = self.get_transformed_image(path, render_factor)
|
|
|
orig = self._open_pil_image(path)
|
|
|
+ if compare:
|
|
|
+ self._plot_comparison(figsize, render_factor, display_render_factor, orig, result)
|
|
|
+ else:
|
|
|
+ self._plot_solo(figsize, render_factor, display_render_factor, result)
|
|
|
+
|
|
|
+ return self._save_result_image(path, result)
|
|
|
+
|
|
|
+ def _plot_comparison(self, figsize:(int,int), render_factor:int, display_render_factor:bool, orig:Image, result:Image):
|
|
|
fig,axes = plt.subplots(1, 2, figsize=figsize)
|
|
|
self._plot_image(orig, axes=axes[0], figsize=figsize, render_factor=render_factor, display_render_factor=False)
|
|
|
self._plot_image(result, axes=axes[1], figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
|
|
|
- return self._save_result_image(path, result)
|
|
|
+
|
|
|
+ def _plot_solo(self, figsize:(int,int), render_factor:int, display_render_factor:bool, result:Image):
|
|
|
+ fig,axes = plt.subplots(1, 1, figsize=figsize)
|
|
|
+ self._plot_image(result, axes=axes, figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
|
|
|
|
|
|
def _save_result_image(self, source_path:Path, image:Image)->Path:
|
|
|
result_path = self.results_dir/source_path.name
|