generators.py

from fastai.vision import *
from fastai.vision.learner import cnn_config
from .unet import DynamicUnetWide, DynamicUnetDeep
from .loss import FeatureLoss
from .dataset import *


# Weights are implicitly read from ./models/ folder
def gen_inference_wide(root_folder:Path, weights_name:str, nf_factor:int=2, arch=models.resnet101)->Learner:
    data = get_dummy_databunch()
    learn = gen_learner_wide(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch)
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn
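
# Example (illustrative sketch, not part of the original module): loading a wide
# generator for inference. The weights name below is a placeholder; a matching
# `<root_folder>/models/<weights_name>.pth` file must already exist, since
# `learn.load` reads from the learner's models directory.
#
#   learn = gen_inference_wide(root_folder=Path('./'), weights_name='ColorizeStable_gen')
#
# The returned Learner has its model switched to eval mode and is ready for inference.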

def gen_learner_wide(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet101, nf_factor:int=2)->Learner:
    return unet_learner_wide(data, arch=arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                             self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)


# The code below is meant to be merged into fastai v1 ideally
def unet_learner_wide(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                      norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
                      blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None,
                      last_cross:bool=True, bottle:bool=False, nf_factor:int=1, **kwargs:Any)->Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(DynamicUnetWide(body, n_classes=data.c, blur=blur, blur_final=blur_final,
                                      self_attention=self_attention, y_range=y_range, norm_type=norm_type,
                                      last_cross=last_cross, bottle=bottle, nf_factor=nf_factor),
                      data.device)
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


#----------------------------------------------------------------------
# Weights are implicitly read from ./models/ folder
def gen_inference_deep(root_folder:Path, weights_name:str, arch=models.resnet34, nf_factor:float=1.5)->Learner:
    data = get_dummy_databunch()
    learn = gen_learner_deep(data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor)
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn

def gen_learner_deep(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:float=1.5)->Learner:
    return unet_learner_deep(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                             self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)


# The code below is meant to be merged into fastai v1 ideally
def unet_learner_deep(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                      norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None,
                      blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None,
                      last_cross:bool=True, bottle:bool=False, nf_factor:float=1.5, **kwargs:Any)->Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(DynamicUnetDeep(body, n_classes=data.c, blur=blur, blur_final=blur_final,
                                      self_attention=self_attention, y_range=y_range, norm_type=norm_type,
                                      last_cross=last_cross, bottle=bottle, nf_factor=nf_factor),
                      data.device)
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
#-----------------------------
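
# Example (illustrative sketch, not part of the original module): training a deep
# generator on an `ImageDataBunch` named `data` that is assumed to have been built
# elsewhere (e.g. with the helpers imported from .dataset). The weights name passed
# to `learn.save` is a placeholder.
#
#   learn = gen_learner_deep(data=data, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor=1.5)
#   learn.fit_one_cycle(1, max_lr=1e-4)   # standard fastai v1 one-cycle schedule
#   learn.save('ColorizeDeep_gen')        # written to `<data.path>/models/` by default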