generators.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from fastai.vision import *
  2. from fastai.vision.learner import cnn_config
  3. from .unet import DynamicUnetWide, DynamicUnetDeep
  4. from .loss import FeatureLoss
  5. from .dataset import *
# Weights are implicitly read from the `models/` folder under `root_folder` (it becomes `learn.path`)
  7. def gen_inference_wide(
  8. root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101
  9. ) -> Learner:
  10. data = get_dummy_databunch()
  11. learn = gen_learner_wide(
  12. data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
  13. )
  14. learn.path = root_folder
  15. learn.load(weights_name)
  16. learn.model.eval()
  17. return learn
  18. def gen_learner_wide(
  19. data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
  20. ) -> Learner:
  21. return unet_learner_wide(
  22. data,
  23. arch=arch,
  24. wd=1e-3,
  25. blur=True,
  26. norm_type=NormType.Spectral,
  27. self_attention=True,
  28. y_range=(-3.0, 3.0),
  29. loss_func=gen_loss,
  30. nf_factor=nf_factor,
  31. )
# Ideally, the code below would be merged into fastai v1 itself.
  33. def unet_learner_wide(
  34. data: DataBunch,
  35. arch: Callable,
  36. pretrained: bool = True,
  37. blur_final: bool = True,
  38. norm_type: Optional[NormType] = NormType,
  39. split_on: Optional[SplitFuncOrIdxList] = None,
  40. blur: bool = False,
  41. self_attention: bool = False,
  42. y_range: Optional[Tuple[float, float]] = None,
  43. last_cross: bool = True,
  44. bottle: bool = False,
  45. nf_factor: int = 1,
  46. **kwargs: Any
  47. ) -> Learner:
  48. "Build Unet learner from `data` and `arch`."
  49. meta = cnn_config(arch)
  50. body = create_body(arch, pretrained)
  51. model = to_device(
  52. DynamicUnetWide(
  53. body,
  54. n_classes=data.c,
  55. blur=blur,
  56. blur_final=blur_final,
  57. self_attention=self_attention,
  58. y_range=y_range,
  59. norm_type=norm_type,
  60. last_cross=last_cross,
  61. bottle=bottle,
  62. nf_factor=nf_factor,
  63. ),
  64. data.device,
  65. )
  66. learn = Learner(data, model, **kwargs)
  67. learn.split(ifnone(split_on, meta['split']))
  68. if pretrained:
  69. learn.freeze()
  70. apply_init(model[2], nn.init.kaiming_normal_)
  71. return learn
  72. # ----------------------------------------------------------------------
# Weights are implicitly read from the `models/` folder under `root_folder` (it becomes `learn.path`)
  74. def gen_inference_deep(
  75. root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5
  76. ) -> Learner:
  77. data = get_dummy_databunch()
  78. learn = gen_learner_deep(
  79. data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
  80. )
  81. learn.path = root_folder
  82. learn.load(weights_name)
  83. learn.model.eval()
  84. return learn
  85. def gen_learner_deep(
  86. data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
  87. ) -> Learner:
  88. return unet_learner_deep(
  89. data,
  90. arch,
  91. wd=1e-3,
  92. blur=True,
  93. norm_type=NormType.Spectral,
  94. self_attention=True,
  95. y_range=(-3.0, 3.0),
  96. loss_func=gen_loss,
  97. nf_factor=nf_factor,
  98. )
# Ideally, the code below would be merged into fastai v1 itself.
  100. def unet_learner_deep(
  101. data: DataBunch,
  102. arch: Callable,
  103. pretrained: bool = True,
  104. blur_final: bool = True,
  105. norm_type: Optional[NormType] = NormType,
  106. split_on: Optional[SplitFuncOrIdxList] = None,
  107. blur: bool = False,
  108. self_attention: bool = False,
  109. y_range: Optional[Tuple[float, float]] = None,
  110. last_cross: bool = True,
  111. bottle: bool = False,
  112. nf_factor: float = 1.5,
  113. **kwargs: Any
  114. ) -> Learner:
  115. "Build Unet learner from `data` and `arch`."
  116. meta = cnn_config(arch)
  117. body = create_body(arch, pretrained)
  118. model = to_device(
  119. DynamicUnetDeep(
  120. body,
  121. n_classes=data.c,
  122. blur=blur,
  123. blur_final=blur_final,
  124. self_attention=self_attention,
  125. y_range=y_range,
  126. norm_type=norm_type,
  127. last_cross=last_cross,
  128. bottle=bottle,
  129. nf_factor=nf_factor,
  130. ),
  131. data.device,
  132. )
  133. learn = Learner(data, model, **kwargs)
  134. learn.split(ifnone(split_on, meta['split']))
  135. if pretrained:
  136. learn.freeze()
  137. apply_init(model[2], nn.init.kaiming_normal_)
  138. return learn
  139. # -----------------------------