# filters.py
from numpy import ndarray
from abc import ABC, abstractmethod
from .generators import Unet34, GeneratorModule
from .transforms import BlackAndWhiteTransform
from fastai.torch_imports import *
from fastai.core import *
from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
from fastai.transforms import CropType, NoCrop, Denormalize, Scale, scale_to
import math
from scipy import misc
  11. class Padding():
  12. def __init__(self, top:int, bottom:int, left:int, right:int):
  13. self.top = top
  14. self.bottom = bottom
  15. self.left = left
  16. self.right = right
  17. class Filter(ABC):
  18. def __init__(self, tfms:[Transform]):
  19. super().__init__()
  20. self.tfms=tfms
  21. self.denorm = Denormalize(*inception_stats)
  22. @abstractmethod
  23. def filter(self, orig_image:ndarray, render_factor:int)->ndarray:
  24. pass
  25. def _transform(self, orig:ndarray, sz:int):
  26. for tfm in self.tfms:
  27. orig,_=tfm(orig, False)
  28. _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
  29. val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not (isinstance(tfm, NoCrop) or isinstance(tfm, Scale))]
  30. orig = val_tfms(orig)
  31. return orig
  32. def _scale_to_square(self, orig:ndarray, targ:int, interpolation=cv2.INTER_AREA):
  33. r,c,*_ = orig.shape
  34. ratio = targ/max(r,c)
  35. #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
  36. #I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
  37. sz = (targ, targ)
  38. return cv2.resize(orig, sz, interpolation=interpolation)
  39. def _get_model_ready_image_ndarray(self, orig:ndarray, sz:int):
  40. result = self._scale_to_square(orig, sz)
  41. sz=result.shape[0]
  42. result = self._transform(result, sz)
  43. return result
  44. def _denorm(self, image: ndarray):
  45. if len(image.shape)==3:
  46. image = image[None]
  47. return self.denorm(np.rollaxis(image,1,4))
  48. def _model_process(self, model:GeneratorModule, orig:ndarray, sz:int):
  49. orig = self._get_model_ready_image_ndarray(orig, sz)
  50. orig = VV_(orig[None])
  51. result = model(orig)
  52. result = result.detach().cpu().numpy()
  53. result = self._denorm(result)
  54. return result[0]
  55. def _convert_to_pil(self, im_array:ndarray):
  56. im_array = np.clip(im_array,0,1)
  57. return misc.toimage(im_array)
  58. class Colorizer(Filter):
  59. def __init__(self, gpu:int, weights_path:Path):
  60. super().__init__(tfms=[BlackAndWhiteTransform()])
  61. self.model = Unet34(nf_factor=2).cuda(gpu)
  62. load_model(self.model, weights_path)
  63. self.model.eval()
  64. torch.no_grad()
  65. self.render_base = 32
  66. def filter(self, orig_image:ndarray, render_factor:int=14)->ndarray:
  67. render_sz = render_factor * self.render_base
  68. model_image = self._model_process(self.model, orig=orig_image, sz=render_sz)
  69. return self._post_process(model_image, orig_image)
  70. #This takes advantage of the fact that human eyes are much less sensitive to
  71. #imperfections in chrominance compared to luminance. This means we can
  72. #save a lot on memory and processing in the model, yet get a great high
  73. #resolution result at the end. This is primarily intended just for
  74. #inference
  75. def _post_process(self, raw_color:ndarray, orig:ndarray):
  76. for tfm in self.tfms:
  77. orig,_=tfm(orig, False)
  78. sz = (orig.shape[1], orig.shape[0])
  79. raw_color = cv2.resize(raw_color, sz, interpolation=cv2.INTER_AREA)
  80. color_yuv = cv2.cvtColor(raw_color, cv2.COLOR_BGR2YUV)
  81. #do a black and white transform first to get better luminance values
  82. orig_yuv = cv2.cvtColor(orig, cv2.COLOR_BGR2YUV)
  83. hires = np.copy(orig_yuv)
  84. hires[:,:,1:3] = color_yuv[:,:,1:3]
  85. return cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)