# generators.py
  1. from fastai.vision import *
  2. from fastai.vision.learner import cnn_config
  3. from .unet import DynamicUnetWide, DynamicUnetDeep
  4. from .loss import FeatureLoss
  5. from .dataset import *
  6. # Weights are implicitly read from ./models/ folder
  7. def gen_inference_wide(
  8. root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
  9. data = get_dummy_databunch()
  10. learn = gen_learner_wide(
  11. data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
  12. )
  13. learn.path = root_folder
  14. learn.load(weights_name)
  15. learn.model.eval()
  16. return learn
  17. def gen_learner_wide(
  18. data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
  19. ) -> Learner:
  20. return unet_learner_wide(
  21. data,
  22. arch=arch,
  23. wd=1e-3,
  24. blur=True,
  25. norm_type=NormType.Spectral,
  26. self_attention=True,
  27. y_range=(-3.0, 3.0),
  28. loss_func=gen_loss,
  29. nf_factor=nf_factor,
  30. )
  31. # The code below is meant to be merged into fastaiv1 ideally
  32. def unet_learner_wide(
  33. data: DataBunch,
  34. arch: Callable,
  35. pretrained: bool = True,
  36. blur_final: bool = True,
  37. norm_type: Optional[NormType] = NormType,
  38. split_on: Optional[SplitFuncOrIdxList] = None,
  39. blur: bool = False,
  40. self_attention: bool = False,
  41. y_range: Optional[Tuple[float, float]] = None,
  42. last_cross: bool = True,
  43. bottle: bool = False,
  44. nf_factor: int = 1,
  45. **kwargs: Any
  46. ) -> Learner:
  47. "Build Unet learner from `data` and `arch`."
  48. meta = cnn_config(arch)
  49. body = create_body(arch, pretrained)
  50. model = to_device(
  51. DynamicUnetWide(
  52. body,
  53. n_classes=data.c,
  54. blur=blur,
  55. blur_final=blur_final,
  56. self_attention=self_attention,
  57. y_range=y_range,
  58. norm_type=norm_type,
  59. last_cross=last_cross,
  60. bottle=bottle,
  61. nf_factor=nf_factor,
  62. ),
  63. data.device,
  64. )
  65. learn = Learner(data, model, **kwargs)
  66. learn.split(ifnone(split_on, meta['split']))
  67. if pretrained:
  68. learn.freeze()
  69. apply_init(model[2], nn.init.kaiming_normal_)
  70. return learn
  71. # ----------------------------------------------------------------------
  72. # Weights are implicitly read from ./models/ folder
  73. def gen_inference_deep(
  74. root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
  75. data = get_dummy_databunch()
  76. learn = gen_learner_deep(
  77. data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
  78. )
  79. learn.path = root_folder
  80. learn.load(weights_name)
  81. learn.model.eval()
  82. return learn
  83. def gen_learner_deep(
  84. data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
  85. ) -> Learner:
  86. return unet_learner_deep(
  87. data,
  88. arch,
  89. wd=1e-3,
  90. blur=True,
  91. norm_type=NormType.Spectral,
  92. self_attention=True,
  93. y_range=(-3.0, 3.0),
  94. loss_func=gen_loss,
  95. nf_factor=nf_factor,
  96. )
  97. # The code below is meant to be merged into fastaiv1 ideally
  98. def unet_learner_deep(
  99. data: DataBunch,
  100. arch: Callable,
  101. pretrained: bool = True,
  102. blur_final: bool = True,
  103. norm_type: Optional[NormType] = NormType,
  104. split_on: Optional[SplitFuncOrIdxList] = None,
  105. blur: bool = False,
  106. self_attention: bool = False,
  107. y_range: Optional[Tuple[float, float]] = None,
  108. last_cross: bool = True,
  109. bottle: bool = False,
  110. nf_factor: float = 1.5,
  111. **kwargs: Any
  112. ) -> Learner:
  113. "Build Unet learner from `data` and `arch`."
  114. meta = cnn_config(arch)
  115. body = create_body(arch, pretrained)
  116. model = to_device(
  117. DynamicUnetDeep(
  118. body,
  119. n_classes=data.c,
  120. blur=blur,
  121. blur_final=blur_final,
  122. self_attention=self_attention,
  123. y_range=y_range,
  124. norm_type=norm_type,
  125. last_cross=last_cross,
  126. bottle=bottle,
  127. nf_factor=nf_factor,
  128. ),
  129. data.device,
  130. )
  131. learn = Learner(data, model, **kwargs)
  132. learn.split(ifnone(split_on, meta['split']))
  133. if pretrained:
  134. learn.freeze()
  135. apply_init(model[2], nn.init.kaiming_normal_)
  136. return learn
  137. # -----------------------------