filters.py

from numpy import ndarray
from abc import ABC, abstractmethod
from .critics import colorize_crit_learner
from fastai.core import *
from fastai.vision import *
from fastai.vision.image import *
from fastai.vision.data import *
from fastai import *
import math
from scipy import misc
import cv2
from PIL import Image as PilImage


class IFilter(ABC):
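    """Interface for image-to-image inference filters.

    `orig_image` is the untouched source image, `filtered_image` is the output of
    any previously applied filters, and `render_factor` controls the size of the
    square image fed to the model.
    """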
    @abstractmethod
    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
    ) -> PilImage:
        pass


class BaseFilter(IFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__()
        self.learn = learn
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        # A simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symmetric, constant, etc.). Not as good!
        targ_sz = (targ, targ)
        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        result = self._scale_to_square(orig, sz)
        result = self._transform(result)
        return result

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
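        # Resize/transform the input, normalize it, run a single-item batch
        # through the model, then denormalize and convert back to a PIL image.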
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)
        x, y = self.norm((x, x), do_x=True)
        result = self.learn.pred_batch(
            ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
        )
        out = result[0]
        out = self.denorm(out.px, do_x=False)
        out = image2np(out * 255).astype(np.uint8)
        return PilImage.fromarray(out)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        # Resize the square model output back to the original image's dimensions.
        targ_sz = orig.size
        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
        return image


class ColorizerFilter(BaseFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__(learn=learn, stats=stats)
        self.render_base = 16
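        # The model input is a square of side render_factor * render_base,
        # e.g. render_factor=35 yields a 560x560 image fed to the network.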

    def filter(
        self,
        orig_image: PilImage,
        filtered_image: PilImage,
        render_factor: int,
        post_process: bool = True,
    ) -> PilImage:
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if post_process:
            return self._post_process(raw_color, orig_image)
        else:
            return raw_color

    def _transform(self, image: PilImage) -> PilImage:
        # Drop the color information so the model always sees a grayscale RGB input.
        return image.convert('LA').convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance. This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end. This is primarily intended just for
    # inference.
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
        # do a black and white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
        hires = np.copy(orig_yuv)
        # Keep the original's luminance channel; take the chrominance channels from the model output.
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
        final = PilImage.fromarray(final)
        return final


class MasterFilter(BaseFilter):
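    """Runs a sequence of filters, feeding each filter's output into the next."""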
    def __init__(self, filters: [IFilter], render_factor: int):
        self.filters = filters
        self.render_factor = render_factor

    def filter(
        self,
        orig_image: PilImage,
        filtered_image: PilImage,
        render_factor: int = None,
        post_process: bool = True,
    ) -> PilImage:
        render_factor = self.render_factor if render_factor is None else render_factor
        for filter in self.filters:
            filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)
        return filtered_image
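

# Example usage (a minimal sketch, not part of the original module). It assumes a
# trained fastai colorization Learner named `learn` is already available; how it
# is built is not shown here, and the file names are placeholders.
#
#     from PIL import Image
#
#     colorizer = ColorizerFilter(learn=learn)
#     master = MasterFilter([colorizer], render_factor=35)
#     orig = Image.open('old_photo.jpg').convert('RGB')
#     # The original image is passed as both the source and the current filtered image.
#     result = master.filter(orig, orig, render_factor=35)
#     result.save('colorized.jpg')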