filters.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from numpy import ndarray
  2. from abc import ABC, abstractmethod
  3. from .generators import Unet34, Unet101, GeneratorModule
  4. from .transforms import BlackAndWhiteTransform
  5. from fastai.torch_imports import *
  6. from fastai.core import *
  7. from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
  8. from fastai.transforms import CropType, NoCrop, Denormalize, Scale, scale_to
  9. import math
  10. from scipy import misc
  11. class Padding():
  12. def __init__(self, top:int, bottom:int, left:int, right:int):
  13. self.top = top
  14. self.bottom = bottom
  15. self.left = left
  16. self.right = right
  17. class Filter(ABC):
  18. def __init__(self, tfms:[Transform]):
  19. super().__init__()
  20. self.tfms=tfms
  21. self.denorm = Denormalize(*inception_stats)
  22. @abstractmethod
  23. def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int)->ndarray:
  24. pass
  25. def _init_model(self, model:nn.Module, weights_path:Path):
  26. load_model(model, weights_path)
  27. model.eval()
  28. torch.no_grad()
  29. def _transform(self, orig:ndarray, sz:int):
  30. for tfm in self.tfms:
  31. orig,_=tfm(orig, False)
  32. _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
  33. val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not (isinstance(tfm, NoCrop) or isinstance(tfm, Scale))]
  34. orig = val_tfms(orig)
  35. return orig
  36. def _scale_to_square(self, orig:ndarray, targ:int, interpolation=cv2.INTER_AREA):
  37. r,c,*_ = orig.shape
  38. ratio = targ/max(r,c)
  39. #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
  40. #I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
  41. sz = (targ, targ)
  42. return cv2.resize(orig, sz, interpolation=interpolation)
  43. def _get_model_ready_image_ndarray(self, orig:ndarray, sz:int):
  44. result = self._scale_to_square(orig, sz)
  45. sz=result.shape[0]
  46. result = self._transform(result, sz)
  47. return result
  48. def _denorm(self, image: ndarray):
  49. if len(image.shape)==3:
  50. image = image[None]
  51. return self.denorm(np.rollaxis(image,1,4))
  52. def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int, gpu:int):
  53. orig = self._get_model_ready_image_ndarray(orig, sz)
  54. orig = VV_(orig[None])
  55. orig = orig.to(device=gpu)
  56. result = model(orig)
  57. result = result.detach().cpu().numpy()
  58. result = self._denorm(result)
  59. return result[0]
  60. def _convert_to_pil(self, im_array:ndarray):
  61. im_array = np.clip(im_array,0,1)
  62. return misc.toimage(im_array)
  63. def _unsquare(self, result:ndarray, orig:ndarray):
  64. sz = (orig.shape[1], orig.shape[0])
  65. return cv2.resize(result, sz, interpolation=cv2.INTER_AREA)
  66. class AbstractColorizer(Filter):
  67. def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
  68. super().__init__(tfms=[BlackAndWhiteTransform()])
  69. self.model = self._get_model(nf_factor=nf_factor, gpu=gpu)
  70. self.gpu = gpu
  71. self._init_model(self.model, weights_path)
  72. self.render_base=16
  73. self.map_to_orig=map_to_orig
  74. @abstractmethod
  75. def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
  76. pass
  77. def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int=36)->ndarray:
  78. render_sz = render_factor * self.render_base
  79. model_image = self._model_process(self.model, orig=filtered_image, sz=render_sz, gpu=self.gpu)
  80. if self.map_to_orig:
  81. return self._post_process(model_image, orig_image)
  82. else:
  83. return self._post_process(model_image, filtered_image)
  84. #This takes advantage of the fact that human eyes are much less sensitive to
  85. #imperfections in chrominance compared to luminance. This means we can
  86. #save a lot on memory and processing in the model, yet get a great high
  87. #resolution result at the end. This is primarily intended just for
  88. #inference
  89. def _post_process(self, raw_color:ndarray, orig:ndarray):
  90. for tfm in self.tfms:
  91. orig,_=tfm(orig, False)
  92. raw_color = self._unsquare(raw_color, orig)
  93. color_yuv = cv2.cvtColor(raw_color, cv2.COLOR_BGR2YUV)
  94. #do a black and white transform first to get better luminance values
  95. orig_yuv = cv2.cvtColor(orig, cv2.COLOR_BGR2YUV)
  96. hires = np.copy(orig_yuv)
  97. hires[:,:,1:3] = color_yuv[:,:,1:3]
  98. return cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
  99. class Colorizer34(AbstractColorizer):
  100. def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
  101. super().__init__(gpu=gpu, weights_path=weights_path, nf_factor=nf_factor, map_to_orig=map_to_orig)
  102. def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
  103. return Unet34(nf_factor=nf_factor).cuda(gpu)
  104. class Colorizer101(AbstractColorizer):
  105. def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2, map_to_orig:bool=True):
  106. super().__init__(gpu=gpu, weights_path=weights_path, nf_factor=nf_factor, map_to_orig=map_to_orig)
  107. def _get_model(self, nf_factor:int, gpu:int)->GeneratorModule:
  108. return Unet101(nf_factor=nf_factor).cuda(gpu)
# TODO: We may not want square rendering here as in colorization - it visibly
# loses fidelity (though not too badly). Will revisit.
  111. class DeFader(Filter):
  112. def __init__(self, gpu:int, weights_path:Path, nf_factor:int=2):
  113. super().__init__(tfms=[BlackAndWhiteTransform()])
  114. self.model = Unet34(nf_factor=nf_factor).cuda(gpu)
  115. self._init_model(self.model, weights_path)
  116. self.render_base=16
  117. self.gpu = gpu
  118. def filter(self, orig_image:ndarray, filtered_image:ndarray, render_factor:int=36)->ndarray:
  119. render_sz = render_factor * self.render_base
  120. model_image = self._model_process(self.model, orig=filtered_image, sz=render_sz, gpu=self.gpu)
  121. return self._post_process(model_image, filtered_image)
  122. def _post_process(self, result:ndarray, orig:ndarray):
  123. return self._unsquare(result, orig)