123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- from numpy import ndarray
- from abc import ABC, abstractmethod
- from .critics import colorize_crit_learner
- from fastai.core import *
- from fastai.vision import *
- from fastai.vision.image import *
- from fastai.vision.data import *
- from fastai import *
- import math
- from scipy import misc
- import cv2
- from PIL import Image as PilImage
class IFilter(ABC):
    """Interface for one image-processing stage.

    Stages are meant to be chained: ``MasterFilter`` in this module calls
    ``filter`` on each stage in turn and passes the previous stage's output
    in as ``filtered_image``.
    """

    @abstractmethod
    def filter(
        self,
        orig_image: PilImage,
        filtered_image: PilImage,
        render_factor: int,
        post_process: bool = True,
    ) -> PilImage:
        """Apply this stage and return the resulting image.

        Args:
            orig_image: The image as it was before any stage ran.
            filtered_image: The image produced by the preceding stage (or the
                starting image for the first stage).
            render_factor: Stage-specific quality/size knob; its exact meaning
                is defined by the implementing class.
            post_process: Whether the stage should run its optional
                post-processing step. Declared here (with a default) because
                ``MasterFilter`` passes it positionally to every stage.

        Returns:
            The processed image.
        """
        pass
class BaseFilter(IFilter):
    """Shared plumbing for filters that push an image through a fastai learner.

    This class does not implement ``filter`` itself; subclasses supply it and
    use the helpers below to resize, normalize, run the model and restore the
    original dimensions.
    """

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        """Store the learner and build the normalize/denormalize functions.

        Args:
            learn: Learner whose model performs the actual image processing.
            stats: Statistics unpacked into ``normalize_funcs``.
        """
        super().__init__()
        self.learn = learn
        # Input tensors are moved to whichever device holds the model weights.
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        """Hook applied after squaring; identity here, overridable by subclasses."""
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        """Stretch ``orig`` to a ``targ`` x ``targ`` square (aspect ratio is not kept)."""
        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symmetric, constant, etc). Not as good!
        targ_sz = (targ, targ)
        # Use the explicitly imported PilImage module rather than relying on
        # the name ``PIL`` leaking in through the fastai star imports.
        return orig.resize(targ_sz, resample=PilImage.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        """Return ``orig`` squared to ``sz`` and passed through ``_transform``."""
        result = self._scale_to_square(orig, sz)
        result = self._transform(result)
        return result

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
        """Run ``orig`` through the learner at resolution ``sz`` x ``sz``.

        Returns:
            The model output as a PIL image. If the prediction raises a
            ``RuntimeError`` whose message contains ``'memory'``, a warning is
            printed and the squared/transformed model input is returned
            instead (note: this is the resized input, not ``orig`` itself).

        Raises:
            RuntimeError: Any prediction error not mentioning ``'memory'``.
        """
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)  # scale pixel values by 1/255 in place before normalizing
        x, y = self.norm((x, x), do_x=True)

        try:
            # x[None] / y[None] add a leading batch dimension of size one.
            result = self.learn.pred_batch(
                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
            )
        except RuntimeError as rerr:
            if 'memory' not in str(rerr):
                # Bare raise re-raises the active exception unchanged.
                raise
            print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
            return model_image

        out = result[0]
        out = self.denorm(out.px, do_x=False)
        out = image2np(out * 255).astype(np.uint8)
        return PilImage.fromarray(out)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        """Resize ``image`` to the dimensions of ``orig``."""
        targ_sz = orig.size
        image = image.resize(targ_sz, resample=PilImage.BILINEAR)
        return image
class ColorizerFilter(BaseFilter):
    """Filter that colorizes an image with the wrapped learner."""

    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        """Set up the base filter and the render size multiplier."""
        super().__init__(learn=learn, stats=stats)
        # The model input edge length is render_factor * render_base pixels.
        self.render_base = 16

    def filter(
        self,
        orig_image: PilImage,
        filtered_image: PilImage,
        render_factor: int,
        post_process: bool = True,
    ) -> PilImage:
        """Colorize ``filtered_image`` and return it at ``orig_image``'s size.

        Args:
            orig_image: Image that supplies the output dimensions and, when
                ``post_process`` is true, the luminance channel.
            filtered_image: Image that is fed to the model.
            render_factor: Multiplied by ``render_base`` to get the square
                edge length the model runs at.
            post_process: If true, merge the model's chrominance with the
                luminance of ``orig_image``; otherwise return the resized
                model output as-is.

        Returns:
            The colorized image.
        """
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if post_process:
            return self._post_process(raw_color, orig_image)
        else:
            return raw_color

    def _transform(self, image: PilImage) -> PilImage:
        """Round-trip the image through mode 'LA' and back to 'RGB'."""
        return image.convert('LA').convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance.  This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end.  This is primarily intended just for
    # inference
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        """Combine the Y channel of ``orig`` with the U/V channels of ``raw_color``.

        Both images are expected to have the same dimensions (``filter``
        resizes ``raw_color`` to ``orig.size`` before calling this).
        NOTE(review): assumes both are 3-channel RGB images -- confirm callers
        never pass other modes.
        """
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        # Arrays obtained from PIL images are in RGB channel order, so the RGB
        # conversion codes are required; the BGR codes would treat the red
        # channel as blue and compute luminance with swapped weights.
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_RGB2YUV)
        # do a black and white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2YUV)
        hires = np.copy(orig_yuv)
        # Keep channel 0 (Y) from the original; take channels 1-2 (U, V) from the model output.
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        final = cv2.cvtColor(hires, cv2.COLOR_YUV2RGB)
        final = PilImage.fromarray(final)
        return final
class MasterFilter(BaseFilter):
    """Composite filter that applies a list of filters one after another."""

    def __init__(self, filters: [IFilter], render_factor: int):
        """Remember the stages to run and the default render factor.

        NOTE: ``BaseFilter.__init__`` is not invoked here, so the learner
        related attributes it would set are absent on instances of this class.
        """
        self.render_factor = render_factor
        self.filters = filters

    def filter(
        self,
        orig_image: PilImage,
        filtered_image: PilImage,
        render_factor: int = None,
        post_process: bool = True,
    ) -> PilImage:
        """Run every stage in order, chaining each output into the next stage.

        Args:
            orig_image: Passed unchanged to every stage.
            filtered_image: Input to the first stage.
            render_factor: Overrides the instance default when not ``None``.
            post_process: Forwarded to every stage.

        Returns:
            The output of the last stage, or ``filtered_image`` unchanged when
            there are no stages.
        """
        if render_factor is None:
            render_factor = self.render_factor

        current = filtered_image
        for stage in self.filters:
            current = stage.filter(orig_image, current, render_factor, post_process)
        return current
|