visualize.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. from fastai.core import *
  2. from fastai.vision import *
  3. from matplotlib.axes import Axes
  4. from matplotlib.figure import Figure
  5. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  6. from .filters import IFilter, MasterFilter, ColorizerFilter
  7. from .generators import gen_inference_deep, gen_inference_wide
  8. from IPython.display import display
  9. from tensorboardX import SummaryWriter
  10. from scipy import misc
  11. from PIL import Image
  12. import ffmpeg
  13. import youtube_dl
  14. import gc
  class ModelImageVisualizer():
      """Apply an image filter to pictures on disk and plot and/or save the result.

      Instance attributes (set in ``__init__``):
        filter      -- object whose ``filter(orig, current, render_factor=...)``
                       method returns the transformed image.
        results_dir -- ``Path`` that ``plot_transformed_image`` saves into, or
                       ``None`` to skip saving.
      """

      def __init__(self, filter:IFilter, results_dir:str=None):
          self.filter = filter
          self.results_dir = None if results_dir is None else Path(results_dir)

      def _clean_mem(self):
          """Hook invoked before every transform; currently a no-op."""
          # Explicit cleanup was disabled in the original code; kept for reference.
          #torch.cuda.empty_cache()
          #gc.collect()
          return

      def _open_pil_image(self, path:Path)->Image:
          """Open the file at ``path`` with PIL and convert it to RGB."""
          return PIL.Image.open(path).convert('RGB')

      def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->None:
          """Plot the original and the transformed image side by side.

          When ``results_dir`` is set, the transformed image is also written
          there under the source file's name. Nothing is returned.
          """
          path = Path(path)
          result = self.get_transformed_image(path, render_factor)
          orig = self._open_pil_image(path)
          _, axes = plt.subplots(1, 2, figsize=figsize)
          # Left panel: untouched input; right panel: filter output.
          self._plot_image(orig, axes=axes[0], figsize=figsize)
          self._plot_image(result, axes=axes[1], figsize=figsize)

          if self.results_dir is not None:
              self._save_result_image(path, result)

      def _save_result_image(self, source_path:Path, image:Image):
          """Save ``image`` into ``results_dir`` using ``source_path``'s file name."""
          # Create the output folder on demand so the first save does not fail
          # just because the directory has not been made yet.
          self.results_dir.mkdir(parents=True, exist_ok=True)
          result_path = self.results_dir/source_path.name
          image.save(result_path)

      def get_transformed_image(self, path:Path, render_factor:int=None)->Image:
          """Load the image at ``path`` and return the output of ``self.filter``."""
          self._clean_mem()
          orig_image = self._open_pil_image(path)
          # The same image is passed as both the original and the current image.
          return self.filter.filter(orig_image, orig_image, render_factor=render_factor)

      def _plot_image(self, image:Image, axes:Axes=None, figsize=(20,20)):
          """Draw ``image`` on ``axes`` with the axis decorations hidden.

          ``figsize`` is only used when ``axes`` is ``None`` and a new figure
          has to be created.
          """
          if axes is None:
              _, axes = plt.subplots(figsize=figsize)
          # Pixel values are divided by 255 before being handed to imshow.
          axes.imshow(np.asarray(image)/255)
          axes.axis('off')

      def _get_num_rows_columns(self, num_images:int, max_columns:int)->(int,int):
          """Return ``(rows, columns)`` of a grid that holds ``num_images`` cells.

          ``columns`` is capped at ``max_columns``; ``rows`` is rounded up so a
          partially filled last row is counted. When there is nothing to lay
          out (no images or no columns allowed) ``(0, 0)`` is returned instead
          of dividing by zero.
          """
          columns = min(num_images, max_columns)
          if columns <= 0:
              return 0, 0
          rows, leftover = divmod(num_images, columns)
          if leftover:
              rows += 1
          return rows, columns
  class VideoColorizer():
      """Colorize a video frame by frame using a ``ModelImageVisualizer``.

      All working data lives under ``./video``: ``source`` (input videos),
      ``bwframes`` (extracted frames), ``colorframes`` (transformed frames),
      ``result`` (re-encoded videos) and ``audio``. The ``audio`` path is only
      stored here; none of the methods in this class use it.
      """

      def __init__(self, vis:ModelImageVisualizer):
          self.vis = vis
          workfolder = Path('./video')
          self.source_folder = workfolder/"source"
          self.bwframes_root = workfolder/"bwframes"
          self.audio_root = workfolder/"audio"
          self.colorframes_root = workfolder/"colorframes"
          self.result_folder = workfolder/"result"

      def _purge_images(self, dir):
          """Delete every entry of ``dir`` whose name contains ``.jpg``."""
          # Raw string: ``\.`` must reach the regex engine as an escaped dot
          # without relying on Python's handling of unknown string escapes.
          jpg_pattern = r'.*?\.jpg'
          for f in os.listdir(dir):
              if re.search(jpg_pattern, f):
                  os.remove(os.path.join(dir, f))

      def _get_fps(self, source_path: Path)->float:
          """Return the video's average frame rate, rounded to a whole number.

          Reads the first stream reported by ``ffmpeg.probe`` whose
          ``codec_type`` is ``'video'`` and evaluates its ``avg_frame_rate``
          fraction (``"num/den"``).

          Raises:
              ValueError: if the probe result contains no video stream.
          """
          probe = ffmpeg.probe(str(source_path))
          stream_data = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
          if stream_data is None:
              raise ValueError('No video stream found in: ' + str(source_path))
          rate_parts = stream_data['avg_frame_rate'].split("/")
          fps_num = rate_parts[0]
          fps_den = rate_parts[1]
          return round(float(fps_num)/float(fps_den))

      def _download_video_from_url(self, source_url, source_path:Path):
          """Download ``source_url`` with youtube_dl to ``source_path``, replacing any existing file."""
          if source_path.exists(): source_path.unlink()

          ydl_opts = {
              'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
              'outtmpl': str(source_path)
              }
          with youtube_dl.YoutubeDL(ydl_opts) as ydl:
              ydl.download([source_url])

      def _extract_raw_frames(self, source_path:Path):
          """Split the video into JPEG frames under ``bwframes_root/<video stem>``.

          Frames left over from a previous run are purged first.
          """
          bwframes_folder = self.bwframes_root/(source_path.stem)
          bwframe_path_template = str(bwframes_folder/'%5d.jpg')
          bwframes_folder.mkdir(parents=True, exist_ok=True)
          self._purge_images(bwframes_folder)
          ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)

      def _colorize_raw_frames(self, source_path:Path):
          """Transform every extracted frame and save it under ``colorframes_root/<video stem>``.

          Output files keep the name of the frame they came from, so the
          frame numbering produced by ``_extract_raw_frames`` is preserved.
          """
          colorframes_folder = self.colorframes_root/(source_path.stem)
          colorframes_folder.mkdir(parents=True, exist_ok=True)
          self._purge_images(colorframes_folder)
          bwframes_folder = self.bwframes_root/(source_path.stem)

          for img in progress_bar(os.listdir(str(bwframes_folder))):
              img_path = bwframes_folder/img

              # Skip anything that is not a regular file (e.g. sub-directories).
              if os.path.isfile(str(img_path)):
                  color_image = self.vis.get_transformed_image(str(img_path))
                  color_image.save(str(colorframes_folder/img))

      def _build_video(self, source_path:Path):
          """Encode the transformed frames into ``result_folder/<source file name>``.

          The frame rate passed to ffmpeg is the value from ``_get_fps`` for
          the source video. An existing result file is removed first.
          """
          result_path = self.result_folder/source_path.name
          colorframes_folder = self.colorframes_root/(source_path.stem)
          colorframes_path_template = str(colorframes_folder/'%5d.jpg')
          result_path.parent.mkdir(parents=True, exist_ok=True)
          if result_path.exists(): result_path.unlink()
          fps = self._get_fps(source_path)

          ffmpeg.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=str(fps)) \
              .output(str(result_path), crf=17, vcodec='libx264') \
              .run(capture_stdout=True)

          print('Video created here: ' + str(result_path))

      def colorize_from_url(self, source_url, file_name:str):
          """Download ``source_url`` to ``source_folder/file_name`` and colorize it."""
          source_path = self.source_folder/file_name
          self._download_video_from_url(source_url, source_path)
          self._colorize_from_path(source_path)

      def colorize_from_file_name(self, file_name:str):
          """Colorize the video stored at ``source_folder/file_name``."""
          source_path = self.source_folder/file_name
          self._colorize_from_path(source_path)

      def _colorize_from_path(self, source_path:Path):
          """Run the pipeline: extract frames, transform them, rebuild the video."""
          self._extract_raw_frames(source_path)
          self._colorize_raw_frames(source_path)
          self._build_video(source_path)
  def get_video_colorizer(render_factor:int=36)->VideoColorizer:
      """Return the default video colorizer, which is the stable one.

      ``render_factor`` is forwarded unchanged to ``get_stable_video_colorizer``;
      every other setting keeps that function's defaults.
      """
      colorizer = get_stable_video_colorizer(render_factor=render_factor)
      return colorizer
  def get_stable_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeImagesStable_gen',
          results_dir='result_images', render_factor:int=36)->VideoColorizer:
      """Build a ``VideoColorizer`` backed by the learner from ``gen_inference_wide``.

      The learner is wrapped in a ``ColorizerFilter`` inside a ``MasterFilter``
      configured with ``render_factor``, then in a ``ModelImageVisualizer``
      that uses ``results_dir``.
      """
      colorizer_filter = ColorizerFilter(
          learn=gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
      )
      master_filter = MasterFilter([colorizer_filter], render_factor=render_factor)
      visualizer = ModelImageVisualizer(master_filter, results_dir=results_dir)
      return VideoColorizer(visualizer)
  def get_artistic_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeImagesArtistic_gen',
          results_dir='result_images', render_factor:int=36)->VideoColorizer:
      """Build a ``VideoColorizer`` backed by the learner from ``gen_inference_deep``.

      The learner is wrapped in a ``ColorizerFilter`` inside a ``MasterFilter``
      configured with ``render_factor``, then in a ``ModelImageVisualizer``
      that uses ``results_dir``.
      """
      colorizer_filter = ColorizerFilter(
          learn=gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
      )
      master_filter = MasterFilter([colorizer_filter], render_factor=render_factor)
      visualizer = ModelImageVisualizer(master_filter, results_dir=results_dir)
      return VideoColorizer(visualizer)
  def get_image_colorizer(render_factor:int=36, artistic:bool=False)->ModelImageVisualizer:
      """Return an image colorizer: the artistic one if ``artistic`` is truthy, else the stable one.

      ``render_factor`` is forwarded to whichever factory is selected.
      """
      factory = get_artistic_image_colorizer if artistic else get_stable_image_colorizer
      return factory(render_factor=render_factor)
  def get_stable_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeImagesStable_gen',
          results_dir='result_images', render_factor:int=36)->ModelImageVisualizer:
      """Build a ``ModelImageVisualizer`` backed by the learner from ``gen_inference_wide``.

      The learner is wrapped in a ``ColorizerFilter`` inside a ``MasterFilter``
      configured with ``render_factor``; results go to ``results_dir``.
      """
      colorizer_filter = ColorizerFilter(
          learn=gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
      )
      master_filter = MasterFilter([colorizer_filter], render_factor=render_factor)
      return ModelImageVisualizer(master_filter, results_dir=results_dir)
  def get_artistic_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeImagesArtistic_gen',
          results_dir='result_images', render_factor:int=36)->ModelImageVisualizer:
      """Build a ``ModelImageVisualizer`` backed by the learner from ``gen_inference_deep``.

      The learner is wrapped in a ``ColorizerFilter`` inside a ``MasterFilter``
      configured with ``render_factor``; results go to ``results_dir``.
      """
      colorizer_filter = ColorizerFilter(
          learn=gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
      )
      master_filter = MasterFilter([colorizer_filter], render_factor=render_factor)
      return ModelImageVisualizer(master_filter, results_dir=results_dir)
  def get_artistic_image_colorizer2(root_folder:Path=Path('./'), weights_name:str='ColorizeImagesArtistic2_gen',
          results_dir='result_images', render_factor:int=36)->ModelImageVisualizer:
      """Build a ``ModelImageVisualizer`` from ``gen_inference_deep`` with the ``ColorizeImagesArtistic2_gen`` weights by default.

      The learner is wrapped in a ``ColorizerFilter`` inside a ``MasterFilter``
      configured with ``render_factor``; results go to ``results_dir``.
      """
      colorizer_filter = ColorizerFilter(
          learn=gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
      )
      master_filter = MasterFilter([colorizer_filter], render_factor=render_factor)
      return ModelImageVisualizer(master_filter, results_dir=results_dir)