filters.py

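"""Inference-time image filters: each Filter wraps a trained generator
network and handles the pre- and post-processing around a single forward
pass (squaring, normalizing, denormalizing, and resizing back)."""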

from numpy import ndarray
from abc import ABC, abstractmethod
from .generators import Unet34, GeneratorModule
from .transforms import BlackAndWhiteTransform
from fastai.torch_imports import *
from fastai.core import *
from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
from fastai.transforms import CropType, NoCrop, Denormalize, Scale, scale_to
import math
from scipy import misc


class Padding():
    def __init__(self, top:int, bottom:int, left:int, right:int):
        self.top = top
        self.bottom = bottom
        self.left = left
        self.right = right


class Filter(ABC):
    def __init__(self, tfms:[Transform]):
        super().__init__()
        self.tfms = tfms
        self.denorm = Denormalize(*inception_stats)

    @abstractmethod
    def filter(self, orig_image:ndarray, render_factor:int)->ndarray:
        pass

    def _init_model(self, model:nn.Module, weights_path:Path):
        load_model(model, weights_path)
        model.eval()

    def _transform(self, orig:ndarray, sz:int):
        for tfm in self.tfms:
            orig, _ = tfm(orig, False)
        _, val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
        # Cropping and scaling are handled by _scale_to_square, so drop those tfms here.
        val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not isinstance(tfm, (NoCrop, Scale))]
        return val_tfms(orig)

    def _scale_to_square(self, orig:ndarray, targ:int, interpolation=cv2.INTER_AREA):
        # A simple stretch to fit a square really makes a big difference in
        # rendering quality/consistency. I've tried padding to the square as
        # well (reflect, symmetric, constant, etc). Not as good!
        return cv2.resize(orig, (targ, targ), interpolation=interpolation)

    def _get_model_ready_image_ndarray(self, orig:ndarray, sz:int):
        result = self._scale_to_square(orig, sz)
        sz = result.shape[0]
        result = self._transform(result, sz)
        return result

    def _denorm(self, image:ndarray):
        if len(image.shape) == 3:
            image = image[None]
        # Denormalize expects channels last, so roll NCHW to NHWC first.
        return self.denorm(np.rollaxis(image, 1, 4))

    def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int):
        orig = self._get_model_ready_image_ndarray(orig, sz)
        orig = VV_(orig[None])
        # Inference only: run the forward pass with gradients disabled.
        with torch.no_grad():
            result = model(orig)
        result = result.detach().cpu().numpy()
        result = self._denorm(result)
        return result[0]

    def _convert_to_pil(self, im_array:ndarray):
        im_array = np.clip(im_array, 0, 1)
        # Note: scipy.misc.toimage is deprecated and removed in newer SciPy
        # releases; PIL's Image.fromarray(np.uint8(im_array * 255)) is a
        # modern equivalent.
        return misc.toimage(im_array)

    def _unsquare(self, result:ndarray, orig:ndarray):
        # Resize the square model output back to the original dimensions.
        sz = (orig.shape[1], orig.shape[0])
        return cv2.resize(result, sz, interpolation=cv2.INTER_AREA)
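

# Concrete filters: each pairs a Unet34 generator with its own pre- and
# post-processing.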
class Colorizer(Filter):
    def __init__(self, gpu:int, weights_path:Path):
        super().__init__(tfms=[BlackAndWhiteTransform()])
        self.model = Unet34(nf_factor=2).cuda(gpu)
        self._init_model(self.model, weights_path)
        self.render_base = 16

    def filter(self, orig_image:ndarray, render_factor:int=36)->ndarray:
        # The model renders at a render_factor * render_base square
        # (576px at the default of 36); higher factors cost more memory
        # but capture more detail.
        render_sz = render_factor * self.render_base
        model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
        return self._post_process(model_image, orig_image)

    # This takes advantage of the fact that human eyes are much less sensitive
    # to imperfections in chrominance than in luminance. That means we can save
    # a lot of memory and processing in the model, yet still get a great
    # high-resolution result at the end. This is primarily intended for inference.
    def _post_process(self, raw_color:ndarray, orig:ndarray):
        # Apply the black and white transform first to get better luminance values.
        for tfm in self.tfms:
            orig, _ = tfm(orig, False)
        raw_color = self._unsquare(raw_color, orig)
        color_yuv = cv2.cvtColor(raw_color, cv2.COLOR_BGR2YUV)
        orig_yuv = cv2.cvtColor(orig, cv2.COLOR_BGR2YUV)
        # Keep the original full-resolution luminance (Y); take only the
        # model's chrominance (U, V).
        hires = np.copy(orig_yuv)
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        return cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)


# TODO: May not want to do square rendering here like in colorization; it
# definitely loses fidelity visibly (but not too terribly). Will revisit.
class DeFader(Filter):
    def __init__(self, gpu:int, weights_path:Path):
        super().__init__(tfms=[])
        self.model = Unet34(nf_factor=2).cuda(gpu)
        self._init_model(self.model, weights_path)
        self.render_base = 16

    def filter(self, orig_image:ndarray, render_factor:int=36)->ndarray:
        render_sz = render_factor * self.render_base
        model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
        return self._unsquare(model_image, orig_image)
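

# Example usage (a minimal sketch: the weights path and image filenames are
# placeholders, and cv2.imread's BGR output is the layout these filters expect):
#
#   colorizer = Colorizer(gpu=0, weights_path=Path('weights/colorizer.h5'))
#   image = cv2.imread('old_photo.jpg')
#   color = colorizer.filter(image, render_factor=36)
#   cv2.imwrite('old_photo_color.jpg', color)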