filters.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from numpy import ndarray
  2. from abc import ABC, abstractmethod
  3. from .critics import colorize_crit_learner
  4. from fastai.core import *
  5. from fastai.vision import *
  6. from fastai.vision.image import *
  7. from fastai.vision.data import *
  8. from fastai import *
  9. import math
  10. from scipy import misc
  11. #from torchvision.transforms.functional import *
  12. import cv2
  13. from PIL import Image as PilImage
  14. class IFilter(ABC):
  15. @abstractmethod
  16. def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:
  17. pass
  18. class BaseFilter(IFilter):
  19. def __init__(self, learn:Learner):
  20. super().__init__()
  21. self.learn=learn
  22. self.norm, self.denorm = normalize_funcs(*imagenet_stats)
  23. def _transform(self, image:PilImage)->PilImage:
  24. return image
  25. def _scale_to_square(self, orig:PilImage, targ:int)->PilImage:
  26. #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
  27. #I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
  28. targ_sz = (targ, targ)
  29. return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
  30. def _get_model_ready_image(self, orig:PilImage, sz:int)->PilImage:
  31. result = self._scale_to_square(orig, sz)
  32. result = self._transform(result)
  33. return result
  34. def _model_process(self, orig:PilImage, sz:int)->PilImage:
  35. model_image = self._get_model_ready_image(orig, sz)
  36. x = pil2tensor(model_image,np.float32)
  37. x.div_(255)
  38. x,y = self.norm((x,x), do_x=True)
  39. result = self.learn.pred_batch(ds_type=DatasetType.Valid,
  40. batch=(x[None].cuda(),y[None]), reconstruct=True)
  41. out = result[0]
  42. out = self.denorm(out.px, do_x=False)
  43. out = image2np(out*255).astype(np.uint8)
  44. return PilImage.fromarray(out)
  45. def _unsquare(self, image:PilImage, orig:PilImage)->PilImage:
  46. targ_sz = orig.size
  47. image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
  48. return image
  49. class ColorizerFilter(BaseFilter):
  50. def __init__(self, learn:Learner, map_to_orig:bool=True):
  51. super().__init__(learn=learn)
  52. self.render_base=16
  53. self.map_to_orig=map_to_orig
  54. def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:
  55. render_sz = render_factor * self.render_base
  56. model_image = self._model_process(orig=filtered_image, sz=render_sz)
  57. if self.map_to_orig:
  58. return self._post_process(model_image, orig_image)
  59. else:
  60. return self._post_process(model_image, filtered_image)
  61. def _transform(self, image:PilImage)->PilImage:
  62. return image.convert('LA').convert('RGB')
  63. #This takes advantage of the fact that human eyes are much less sensitive to
  64. #imperfections in chrominance compared to luminance. This means we can
  65. #save a lot on memory and processing in the model, yet get a great high
  66. #resolution result at the end. This is primarily intended just for
  67. #inference
  68. def _post_process(self, raw_color:PilImage, orig:PilImage)->PilImage:
  69. raw_color = self._unsquare(raw_color, orig)
  70. color_np = np.asarray(raw_color)
  71. orig_np = np.asarray(orig)
  72. color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
  73. #do a black and white transform first to get better luminance values
  74. orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
  75. hires = np.copy(orig_yuv)
  76. hires[:,:,1:3] = color_yuv[:,:,1:3]
  77. final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
  78. final = PilImage.fromarray(final)
  79. return final
  80. class MasterFilter(BaseFilter):
  81. def __init__(self, filters:[IFilter], render_factor:int):
  82. self.filters=filters
  83. self.render_factor=render_factor
  84. def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int=None)->PilImage:
  85. render_factor = self.render_factor if render_factor is None else render_factor
  86. for filter in self.filters:
  87. filtered_image=filter.filter(orig_image, filtered_image, render_factor)
  88. return filtered_image