123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- from fastai.core import *
- from fastai.vision import *
- from matplotlib.axes import Axes
- from matplotlib.figure import Figure
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
- from .filters import IFilter, MasterFilter, ColorizerFilter
- from .generators import gen_inference_deep, gen_inference_wide
- from tensorboardX import SummaryWriter
- from scipy import misc
- from PIL import Image
- import ffmpeg
- import youtube_dl
- import gc
- import requests
- from io import BytesIO
- import base64
- from IPython import display as ipythondisplay
- from IPython.display import HTML
- from IPython.display import Image as ipythonimage
- class ModelImageVisualizer():
- def __init__(self, filter:IFilter, results_dir:str=None):
- self.filter = filter
- self.results_dir=None if results_dir is None else Path(results_dir)
- self.results_dir.mkdir(parents=True, exist_ok=True)
-
- def _clean_mem(self):
- torch.cuda.empty_cache()
- #gc.collect()
- def _open_pil_image(self, path:Path)->Image:
- return PIL.Image.open(path).convert('RGB')
- def _get_image_from_url(self, url:str)->Image:
- response = requests.get(url)
- img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
- return img
- def plot_transformed_image_from_url(self, url:str, path:str='test_images/image.png', figsize:(int,int)=(20,20),
- render_factor:int=None, display_render_factor:bool=False, compare:bool=False)->Path:
- img = self._get_image_from_url(url)
- img.save(path)
- return self.plot_transformed_image(path=path, figsize=figsize, render_factor=render_factor,
- display_render_factor=display_render_factor, compare=compare)
- def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None,
- display_render_factor:bool=False, compare:bool=False)->Path:
- path = Path(path)
- result = self.get_transformed_image(path, render_factor)
- orig = self._open_pil_image(path)
- if compare:
- self._plot_comparison(figsize, render_factor, display_render_factor, orig, result)
- else:
- self._plot_solo(figsize, render_factor, display_render_factor, result)
- return self._save_result_image(path, result)
- def _plot_comparison(self, figsize:(int,int), render_factor:int, display_render_factor:bool, orig:Image, result:Image):
- fig,axes = plt.subplots(1, 2, figsize=figsize)
- self._plot_image(orig, axes=axes[0], figsize=figsize, render_factor=render_factor, display_render_factor=False)
- self._plot_image(result, axes=axes[1], figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
-
- def _plot_solo(self, figsize:(int,int), render_factor:int, display_render_factor:bool, result:Image):
- fig,axes = plt.subplots(1, 1, figsize=figsize)
- self._plot_image(result, axes=axes, figsize=figsize, render_factor=render_factor, display_render_factor=display_render_factor)
- def _save_result_image(self, source_path:Path, image:Image)->Path:
- result_path = self.results_dir/source_path.name
- image.save(result_path)
- return result_path
- def get_transformed_image(self, path:Path, render_factor:int=None)->Image:
- self._clean_mem()
- orig_image = self._open_pil_image(path)
- filtered_image = self.filter.filter(orig_image, orig_image, render_factor=render_factor)
- return filtered_image
- def _plot_image(self, image:Image, render_factor:int, axes:Axes=None, figsize=(20,20), display_render_factor:bool=False):
- if axes is None:
- _,axes = plt.subplots(figsize=figsize)
- axes.imshow(np.asarray(image)/255)
- axes.axis('off')
- if render_factor is not None and display_render_factor:
- plt.text(10,10,'render_factor: ' + str(render_factor), color='white', backgroundcolor='black')
- def _get_num_rows_columns(self, num_images:int, max_columns:int)->(int,int):
- columns = min(num_images, max_columns)
- rows = num_images//columns
- rows = rows if rows * columns == num_images else rows + 1
- return rows, columns
- class VideoColorizer():
- def __init__(self, vis:ModelImageVisualizer):
- self.vis=vis
- workfolder = Path('./video')
- self.source_folder = workfolder/"source"
- self.bwframes_root = workfolder/"bwframes"
- self.audio_root = workfolder/"audio"
- self.colorframes_root = workfolder/"colorframes"
- self.result_folder = workfolder/"result"
- def _purge_images(self, dir):
- for f in os.listdir(dir):
- if re.search('.*?\.jpg', f):
- os.remove(os.path.join(dir, f))
- def _get_fps(self, source_path: Path)->str:
- probe = ffmpeg.probe(str(source_path))
- stream_data = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
- return stream_data['avg_frame_rate']
- def _download_video_from_url(self, source_url, source_path:Path):
- if source_path.exists(): source_path.unlink()
- ydl_opts = {
- 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',
- 'outtmpl': str(source_path)
- }
- with youtube_dl.YoutubeDL(ydl_opts) as ydl:
- ydl.download([source_url])
- def _extract_raw_frames(self, source_path:Path):
- bwframes_folder = self.bwframes_root/(source_path.stem)
- bwframe_path_template = str(bwframes_folder/'%5d.jpg')
- bwframes_folder.mkdir(parents=True, exist_ok=True)
- self._purge_images(bwframes_folder)
- ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)
- def _colorize_raw_frames(self, source_path:Path, render_factor:int=None):
- colorframes_folder = self.colorframes_root/(source_path.stem)
- colorframes_folder.mkdir(parents=True, exist_ok=True)
- self._purge_images(colorframes_folder)
- bwframes_folder = self.bwframes_root/(source_path.stem)
- for img in progress_bar(os.listdir(str(bwframes_folder))):
- img_path = bwframes_folder/img
- if os.path.isfile(str(img_path)):
- color_image = self.vis.get_transformed_image(str(img_path), render_factor=render_factor)
- color_image.save(str(colorframes_folder/img))
-
- def _build_video(self, source_path:Path)->Path:
- colorized_path = self.result_folder/(source_path.name.replace('.mp4', '_no_audio.mp4'))
- colorframes_folder = self.colorframes_root/(source_path.stem)
- colorframes_path_template = str(colorframes_folder/'%5d.jpg')
- colorized_path.parent.mkdir(parents=True, exist_ok=True)
- if colorized_path.exists(): colorized_path.unlink()
- fps = self._get_fps(source_path)
- ffmpeg.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=fps) \
- .output(str(colorized_path), crf=17, vcodec='libx264') \
- .run(capture_stdout=True)
- result_path = self.result_folder/source_path.name
- if result_path.exists(): result_path.unlink()
- #making copy of non-audio version in case adding back audio doesn't apply or fails.
- shutil.copyfile(str(colorized_path), str(result_path))
- # adding back sound here
- audio_file = Path(str(source_path).replace('.mp4', '.aac'))
- if audio_file.exists(): audio_file.unlink()
- os.system('ffmpeg -y -i "' + str(source_path) + '" -vn -acodec copy "' + str(audio_file) + '"')
- if audio_file.exists:
- os.system('ffmpeg -y -i "' + str(colorized_path) + '" -i "' + str(audio_file)
- + '" -shortest -c:v copy -c:a aac -b:a 256k "' + str(result_path) + '"')
- print('Video created here: ' + str(result_path))
- return result_path
- def colorize_from_url(self, source_url, file_name:str, render_factor:int=None)->Path:
- source_path = self.source_folder/file_name
- self._download_video_from_url(source_url, source_path)
- return self._colorize_from_path(source_path, render_factor=render_factor)
- def colorize_from_file_name(self, file_name:str, render_factor:int=None)->Path:
- source_path = self.source_folder/file_name
- return self._colorize_from_path(source_path, render_factor=render_factor)
- def _colorize_from_path(self, source_path:Path, render_factor:int=None)->Path:
- if not source_path.exists():
- raise Exception('Video at path specfied, ' + str(source_path) + ' could not be found.')
- self._extract_raw_frames(source_path)
- self._colorize_raw_frames(source_path, render_factor=render_factor)
- return self._build_video(source_path)
- def get_video_colorizer(render_factor:int=21)->VideoColorizer:
- return get_stable_video_colorizer(render_factor=render_factor)
- def get_stable_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeVideo_gen',
- results_dir='result_images', render_factor:int=21)->VideoColorizer:
- learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
- filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
- vis = ModelImageVisualizer(filtr, results_dir=results_dir)
- return VideoColorizer(vis)
- def get_image_colorizer(render_factor:int=35, artistic:bool=True)->ModelImageVisualizer:
- if artistic:
- return get_artistic_image_colorizer(render_factor=render_factor)
- else:
- return get_stable_image_colorizer(render_factor=render_factor)
- def get_stable_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeStable_gen',
- results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
- learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
- filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
- vis = ModelImageVisualizer(filtr, results_dir=results_dir)
- return vis
- def get_artistic_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeArtistic_gen',
- results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
- learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
- filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
- vis = ModelImageVisualizer(filtr, results_dir=results_dir)
- return vis
- def show_image_in_notebook(image_path:Path):
- ipythondisplay.display(ipythonimage(str(image_path)))
- def show_video_in_notebook(video_path:Path):
- video = io.open(video_path, 'r+b').read()
- encoded = base64.b64encode(video)
- ipythondisplay.display(HTML(data='''<video alt="test" autoplay
- loop controls style="height: 400px;">
- <source src="data:video/mp4;base64,{0}" type="video/mp4" />
- </video>'''.format(encoded.decode('ascii'))))
|