visualize.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from fastai.core import *
  2. from fastai.vision import *
  3. from matplotlib.axes import Axes
  4. from matplotlib.figure import Figure
  5. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  6. from .filters import IFilter, MasterFilter, ColorizerFilter
  7. from .generators import colorize_gen_inference, colorize_gen_inference2
  8. from IPython.display import display
  9. from tensorboardX import SummaryWriter
  10. from scipy import misc
  11. from PIL import Image
  class ModelImageVisualizer():
      """Apply an image filter to image files, plot the results, and optionally save them.

      Args:
          filter: object whose ``filter(orig, image, render_factor=...)`` method
              produces the transformed image (called by ``get_transformed_image``).
          results_dir: directory into which ``plot_transformed_image`` saves each
              result under the source file's name; ``None`` disables saving.
      """

      def __init__(self, filter:IFilter, results_dir:str=None):
          self.filter = filter
          # None means "don't save"; otherwise normalise to a Path so `/` joins work.
          self.results_dir = None if results_dir is None else Path(results_dir)

      def _open_pil_image(self, path:Path)->Image:
          """Load the image at ``path`` and return an RGB copy of it.

          Uses the ``Image`` module imported from PIL at file level rather than a
          bare ``PIL`` name that is not imported explicitly in this file. The
          ``with`` block closes the source file once the converted copy exists;
          Pillow's ``convert`` returns a new image, so the result stays usable.
          """
          with Image.open(path) as source:
              return source.convert('RGB')

      def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->None:
          """Plot the image at ``path`` next to its filtered version.

          Creates a 1x2 matplotlib figure (original on the left, result on the
          right). If ``results_dir`` is set, the result is also saved there.

          Args:
              path: image file to transform.
              figsize: matplotlib figure size.
              render_factor: forwarded to the filter; ``None`` leaves the choice
                  to the filter.
          """
          path = Path(path)
          result = self.get_transformed_image(path, render_factor)
          orig = self._open_pil_image(path)
          _, axes = plt.subplots(1, 2, figsize=figsize)
          self._plot_image(orig, axes=axes[0], figsize=figsize)
          self._plot_image(result, axes=axes[1], figsize=figsize)
          if self.results_dir is not None:
              self._save_result_image(path, result)

      def _save_result_image(self, source_path:Path, image:Image)->Path:
          """Save ``image`` into ``results_dir`` under ``source_path``'s file name.

          Creates ``results_dir`` (including any missing parents) first, so saving
          does not fail when the directory has not been created yet.

          Returns:
              The path the image was written to.
          """
          self.results_dir.mkdir(parents=True, exist_ok=True)
          result_path = self.results_dir/source_path.name
          image.save(result_path)
          return result_path

      def get_transformed_image(self, path:Path, render_factor:int=None)->Image:
          """Open ``path`` as RGB and return the output of ``self.filter`` on it."""
          orig_image = self._open_pil_image(path)
          # NOTE(review): the same image is passed as both the "original" and the
          # image to filter -- presumably IFilter.filter(orig, image, ...); confirm.
          filtered_image = self.filter.filter(orig_image, orig_image, render_factor=render_factor)
          return filtered_image

      def _plot_image(self, image:Image, axes:Axes=None, figsize=(20,20)):
          """Draw ``image`` on ``axes`` (a new figure of ``figsize`` if omitted), hiding the axis."""
          if axes is None:
              _, axes = plt.subplots(figsize=figsize)
          # Scale to [0, 1] floats for imshow; assumes 8-bit channel values.
          axes.imshow(np.asarray(image)/255)
          axes.axis('off')

      def _get_num_rows_columns(self, num_images:int, max_columns:int)->(int,int):
          """Return ``(rows, columns)`` for a grid that holds ``num_images`` images.

          Uses at most ``max_columns`` columns (assumed positive) and just enough
          rows to fit every image. A non-positive image count yields ``(0, 0)``
          instead of a ZeroDivisionError.
          """
          if num_images <= 0:
              return 0, 0
          columns = min(num_images, max_columns)
          rows, leftover = divmod(num_images, columns)
          # A partially filled last row still needs a row of its own.
          if leftover:
              rows += 1
          return rows, columns
  def get_colorize_visualizer(root_folder:Path=Path('./'), weights_name:str='colorize_gen',
          results_dir = 'result_images', nf_factor:float=1.25, render_factor:int=21)->ModelImageVisualizer:
      """Create a ModelImageVisualizer built on ``colorize_gen_inference``.

      The learner returned by ``colorize_gen_inference(root_folder, weights_name,
      nf_factor)`` is wrapped in a ColorizerFilter. That filter is placed inside a
      MasterFilter configured with ``render_factor``. ``results_dir`` is passed
      through to the visualizer unchanged.
      """
      generator = colorize_gen_inference(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor)
      master = MasterFilter([ColorizerFilter(learn=generator)], render_factor=render_factor)
      return ModelImageVisualizer(master, results_dir=results_dir)
  def get_colorize_visualizer2(root_folder:Path=Path('./'), weights_name:str='colorize_gen',
          results_dir = 'result_images', nf_factor:int=1, render_factor:int=21)->ModelImageVisualizer:
      """Create a ModelImageVisualizer built on ``colorize_gen_inference2``.

      The learner returned by ``colorize_gen_inference2(root_folder, weights_name,
      nf_factor)`` is wrapped in a ColorizerFilter inside a MasterFilter that uses
      ``render_factor``. ``results_dir`` is handed to the visualizer as-is.
      """
      colorizer = ColorizerFilter(
          learn=colorize_gen_inference2(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor))
      return ModelImageVisualizer(MasterFilter([colorizer], render_factor=render_factor), results_dir=results_dir)