from fastai.vision import *
from fastai.vision.learner import cnn_config
from .unet import DynamicUnetWide, DynamicUnetDeep
from .loss import FeatureLoss
from .dataset import *


# Weights are implicitly read from the ./models/ folder
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101
) -> Learner:
    # A dummy databunch is enough here: inference only needs the model graph,
    # not real training data.
    data = get_dummy_databunch()
    learn = gen_learner_wide(
        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    learn.path = root_folder
    learn.load(weights_name)
    # Switch to eval mode so BatchNorm/Dropout behave deterministically at inference.
    learn.model.eval()
    return learn
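
# Minimal usage sketch (illustrative only). The folder and weight name below are
# assumptions; the matching .pth file must already exist under root_folder/models/.
# The returned Learner is typically handed to the colorization filters rather than
# called directly.
#
#   learn = gen_inference_wide(root_folder=Path('./'), weights_name='ColorizeStable_gen')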


def gen_learner_wide(
    data: ImageDataBunch,
    gen_loss=FeatureLoss(),
    arch=models.resnet101,
    nf_factor: int = 2,
) -> Learner:
    return unet_learner_wide(
        data,
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
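
# Hypothetical fine-tuning sketch (not DeOldify's actual training recipe): with a
# pretrained encoder the learner comes back frozen, so only the U-Net decoder/head
# trains until it is unfrozen. `my_databunch` is assumed to exist.
#
#   learn = gen_learner_wide(data=my_databunch)
#   learn.fit_one_cycle(1, 1e-3)                # decoder/head only
#   learn.unfreeze()
#   learn.fit_one_cycle(1, slice(1e-5, 1e-3))   # whole U-Net with discriminative LRs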


# The code below would ideally be merged into fastai v1 itself.
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetWide(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    # Split at the pretrained-encoder boundary and freeze it, so that initially only
    # the newly created (non-pretrained) layers are trained and initialised.
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# ----------------------------------------------------------------------
# Weights are implicitly read from the ./models/ folder
def gen_inference_deep(
    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5
) -> Learner:
    data = get_dummy_databunch()
    learn = gen_learner_deep(
        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn
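
# Usage mirrors gen_inference_wide; the deep variant defaults to a resnet34 encoder.
# The weight name below is an assumption and must exist under root_folder/models/:
#
#   learn = gen_inference_deep(root_folder=Path('./'), weights_name='ColorizeArtistic_gen')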


def gen_learner_deep(
    data: ImageDataBunch,
    gen_loss=FeatureLoss(),
    arch=models.resnet34,
    nf_factor: float = 1.5,
) -> Learner:
    return unet_learner_deep(
        data,
        arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )


# The code below would ideally be merged into fastai v1 itself.
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetDeep(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    # Same setup as unet_learner_wide: freeze the pretrained encoder and initialise
    # only the newly added layers.
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn

# -----------------------------