generators.py 4.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from fastai.vision import *
  2. from fastai.vision.learner import cnn_config
  3. from fasterai.unet import DynamicUnet2, DynamicUnet3, DynamicUnet4, DynamicUnet5
  4. from .loss import FeatureLoss
  5. def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):
  6. return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
  7. self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)
# The code below is meant to be merged into fastai v1, ideally.
  9. def unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
  10. norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
  11. blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
  12. bottle:bool=False, **kwargs:Any)->None:
  13. "Build Unet learner from `data` and `arch`."
  14. meta = cnn_config(arch)
  15. body = create_body(arch, pretrained)
  16. model = to_device(DynamicUnet2(body, n_classes=data.c, blur=blur, blur_final=blur_final,
  17. self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
  18. bottle=bottle), data.device)
  19. learn = Learner(data, model, **kwargs)
  20. learn.split(ifnone(split_on,meta['split']))
  21. if pretrained: learn.freeze()
  22. apply_init(model[2], nn.init.kaiming_normal_)
  23. return learn
  24. def unet_learner3(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
  25. norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
  26. blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
  27. bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
  28. "Build Unet learner from `data` and `arch`."
  29. meta = cnn_config(arch)
  30. body = create_body(arch, pretrained)
  31. model = to_device(DynamicUnet3(body, n_classes=data.c, blur=blur, blur_final=blur_final,
  32. self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
  33. bottle=bottle, nf_factor=nf_factor), data.device)
  34. learn = Learner(data, model, **kwargs)
  35. learn.split(ifnone(split_on,meta['split']))
  36. if pretrained: learn.freeze()
  37. apply_init(model[2], nn.init.kaiming_normal_)
  38. return learn
  39. #No batch norm in ESRGAN paper
  40. def unet_learner4(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
  41. norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
  42. blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
  43. bottle:bool=False, nf_factor:float=1.0, **kwargs:Any)->None:
  44. "Build Unet learner from `data` and `arch`."
  45. meta = cnn_config(arch)
  46. body = create_body(arch, pretrained)
  47. model = to_device(DynamicUnet4(body, n_classes=data.c, blur=blur, blur_final=blur_final,
  48. self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
  49. bottle=bottle, nf_factor=nf_factor), data.device)
  50. learn = Learner(data, model, **kwargs)
  51. learn.split(ifnone(split_on,meta['split']))
  52. if pretrained: learn.freeze()
  53. apply_init(model[2], nn.init.kaiming_normal_)
  54. return learn
  55. #No batch norm in ESRGAN paper, custom nf width
  56. def unet_learner5(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
  57. norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
  58. blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
  59. bottle:bool=True, **kwargs:Any)->None:
  60. "Build Unet learner from `data` and `arch`."
  61. meta = cnn_config(arch)
  62. body = create_body(arch, pretrained)
  63. model = to_device(DynamicUnet5(body, n_classes=data.c, blur=blur, blur_final=blur_final,
  64. self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
  65. bottle=bottle), data.device)
  66. learn = Learner(data, model, **kwargs)
  67. learn.split(ifnone(split_on,meta['split']))
  68. if pretrained: learn.freeze()
  69. apply_init(model[2], nn.init.kaiming_normal_)
  70. return learn