visualize.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from fastai.core import *
  2. from fastai.vision import *
  3. from matplotlib.axes import Axes
  4. from matplotlib.figure import Figure
  5. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  6. from .filters import IFilter
  7. from IPython.display import display
  8. from tensorboardX import SummaryWriter
  9. from scipy import misc
  10. from PIL import Image
  11. class ModelImageVisualizer():
  12. def __init__(self, filter:IFilter, results_dir:str=None):
  13. self.filter = filter
  14. self.results_dir=None if results_dir is None else Path(results_dir)
  15. def _open_pil_image(self, path:Path)->Image:
  16. return PIL.Image.open(path).convert('RGB')
  17. def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->Image:
  18. path = Path(path)
  19. result = self.get_transformed_image(path, render_factor)
  20. orig = self._open_pil_image(path)
  21. fig,axes = plt.subplots(1, 2, figsize=figsize)
  22. self._plot_image(orig, axes=axes[0], figsize=figsize)
  23. self._plot_image(result, axes=axes[1], figsize=figsize)
  24. if self.results_dir is not None:
  25. self._save_result_image(path, result)
  26. def _save_result_image(self, source_path:Path, image:Image):
  27. result_path = self.results_dir/source_path.name
  28. image.save(result_path)
  29. def get_transformed_image(self, path:Path, render_factor:int=None)->Image:
  30. orig_image = self._open_pil_image(path)
  31. filtered_image = self.filter.filter(orig_image, orig_image, render_factor=render_factor)
  32. return filtered_image
  33. def _plot_image(self, image:Image, axes:Axes=None, figsize=(20,20)):
  34. if axes is None:
  35. _,axes = plt.subplots(figsize=figsize)
  36. axes.imshow(np.asarray(image)/255)
  37. axes.axis('off')
  38. def _get_num_rows_columns(self, num_images:int, max_columns:int)->(int,int):
  39. columns = min(num_images, max_columns)
  40. rows = num_images//columns
  41. rows = rows if rows * columns == num_images else rows + 1
  42. return rows, columns