123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- from fastai.vision import *
- from fastai.vision.learner import cnn_config
- from .unet import DynamicUnetWide, DynamicUnetDeep
- from .loss import FeatureLoss
- from .dataset import *
# Weights are read by `Learner.load` relative to `root_folder` (its `models/`
# subfolder under fastai v1's default `model_dir` -- confirm if customized).
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101
) -> Learner:
    """Build the wide-decoder generator and load trained weights for inference.

    Args:
        root_folder: Base directory assigned to ``learn.path``; a ``str`` is
            accepted and converted to ``Path``.
        weights_name: Weights file name handed to ``learn.load``.
        nf_factor: Decoder width multiplier forwarded to ``gen_learner_wide``.
        arch: Encoder architecture constructor forwarded to ``gen_learner_wide``.

    Returns:
        The Learner with the weights loaded and its model switched to eval mode.
    """
    data = get_dummy_databunch()
    # The loss is irrelevant at inference; passing F.l1_loss presumably just
    # avoids building the heavier default generator loss.
    learn = gen_learner_wide(
        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    # learn.path is set after construction and load() joins it with `/`, so a
    # plain str would fail there; normalize to Path (no-op for Path inputs).
    learn.path = Path(root_folder)
    learn.load(weights_name)
    learn.model.eval()
    return learn
def gen_learner_wide(
    data: ImageDataBunch,
    gen_loss=None,
    arch=models.resnet101,
    nf_factor: int = 2,
) -> Learner:
    """Create the wide-decoder Unet generator Learner.

    Args:
        data: DataBunch the learner trains/infers on.
        gen_loss: Loss function for the generator. When ``None`` (the default)
            a new ``FeatureLoss()`` is created for this learner.
        arch: Encoder architecture constructor.
        nf_factor: Decoder width multiplier.

    Returns:
        A Learner built by ``unet_learner_wide`` with weight decay 1e-3, blur,
        spectral normalization, self-attention and output range (-3.0, 3.0).
    """
    # A `FeatureLoss()` default argument would be evaluated once at import time
    # (an import side effect) and the same module shared by every learner.
    # Build it lazily, per call, instead.
    if gen_loss is None:
        gen_loss = FeatureLoss()
    return unet_learner_wide(
        data,
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
# Intended to be merged upstream into fastai v1 eventually.
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    """Build a Unet learner from `data` and `arch` with a ``DynamicUnetWide`` decoder.

    The encoder comes from ``create_body(arch, pretrained)``. The assembled
    network has ``data.c`` output classes and is moved to ``data.device``.
    Extra ``kwargs`` are forwarded unchanged to ``Learner``.

    Args:
        data: DataBunch supplying ``c`` and ``device``.
        arch: Encoder architecture constructor.
        pretrained: Use pretrained encoder weights and freeze the learner.
        split_on: Layer-group split; defaults to ``cnn_config(arch)['split']``.
        nf_factor: Decoder width multiplier.
        Remaining arguments are forwarded to ``DynamicUnetWide``.

    Returns:
        The configured Learner.

    NOTE(review): the ``norm_type`` default is the ``NormType`` class itself,
    not one of its members; kept for compatibility -- confirm intent.
    """
    arch_meta = cnn_config(arch)
    encoder = create_body(arch, pretrained)
    decoder_options = dict(
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    unet = DynamicUnetWide(encoder, **decoder_options)
    net = to_device(unet, data.device)

    learner = Learner(data, net, **kwargs)
    # Looked up eagerly, exactly as the `ifnone` call evaluated it.
    default_split = arch_meta['split']
    learner.split(default_split if split_on is None else split_on)
    if pretrained:
        learner.freeze()
    # Kaiming init for index 2 of the assembled network.
    apply_init(net[2], nn.init.kaiming_normal_)
    return learner
- # ----------------------------------------------------------------------
# Weights are read by `Learner.load` relative to `root_folder` (its `models/`
# subfolder under fastai v1's default `model_dir` -- confirm if customized).
def gen_inference_deep(
    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5
) -> Learner:
    """Build the deep-decoder generator and load trained weights for inference.

    Args:
        root_folder: Base directory assigned to ``learn.path``; a ``str`` is
            accepted and converted to ``Path``.
        weights_name: Weights file name handed to ``learn.load``.
        arch: Encoder architecture constructor forwarded to ``gen_learner_deep``.
        nf_factor: Decoder width multiplier forwarded to ``gen_learner_deep``.

    Returns:
        The Learner with the weights loaded and its model switched to eval mode.
    """
    data = get_dummy_databunch()
    # The loss is irrelevant at inference; passing F.l1_loss presumably just
    # avoids building the heavier default generator loss.
    learn = gen_learner_deep(
        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    # learn.path is set after construction and load() joins it with `/`, so a
    # plain str would fail there; normalize to Path (no-op for Path inputs).
    learn.path = Path(root_folder)
    learn.load(weights_name)
    learn.model.eval()
    return learn
def gen_learner_deep(
    data: ImageDataBunch,
    gen_loss=None,
    arch=models.resnet34,
    nf_factor: float = 1.5,
) -> Learner:
    """Create the deep-decoder Unet generator Learner.

    Args:
        data: DataBunch the learner trains/infers on.
        gen_loss: Loss function for the generator. When ``None`` (the default)
            a new ``FeatureLoss()`` is created for this learner.
        arch: Encoder architecture constructor.
        nf_factor: Decoder width multiplier.

    Returns:
        A Learner built by ``unet_learner_deep`` with weight decay 1e-3, blur,
        spectral normalization, self-attention and output range (-3.0, 3.0).
    """
    # A `FeatureLoss()` default argument would be evaluated once at import time
    # (an import side effect) and the same module shared by every learner.
    # Build it lazily, per call, instead.
    if gen_loss is None:
        gen_loss = FeatureLoss()
    return unet_learner_deep(
        data,
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )
# Intended to be merged upstream into fastai v1 eventually.
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    """Build a Unet learner from `data` and `arch` with a ``DynamicUnetDeep`` decoder.

    The encoder comes from ``create_body(arch, pretrained)``. The assembled
    network has ``data.c`` output classes and is moved to ``data.device``.
    Extra ``kwargs`` are forwarded unchanged to ``Learner``.

    Args:
        data: DataBunch supplying ``c`` and ``device``.
        arch: Encoder architecture constructor.
        pretrained: Use pretrained encoder weights and freeze the learner.
        split_on: Layer-group split; defaults to ``cnn_config(arch)['split']``.
        nf_factor: Decoder width multiplier.
        Remaining arguments are forwarded to ``DynamicUnetDeep``.

    Returns:
        The configured Learner.

    NOTE(review): the ``norm_type`` default is the ``NormType`` class itself,
    not one of its members; kept for compatibility -- confirm intent.
    """
    arch_meta = cnn_config(arch)
    encoder = create_body(arch, pretrained)
    decoder_options = dict(
        n_classes=data.c,
        blur=blur,
        blur_final=blur_final,
        self_attention=self_attention,
        y_range=y_range,
        norm_type=norm_type,
        last_cross=last_cross,
        bottle=bottle,
        nf_factor=nf_factor,
    )
    unet = DynamicUnetDeep(encoder, **decoder_options)
    net = to_device(unet, data.device)

    learner = Learner(data, net, **kwargs)
    # Looked up eagerly, exactly as the `ifnone` call evaluated it.
    default_split = arch_meta['split']
    learner.split(default_split if split_on is None else split_on)
    if pretrained:
        learner.freeze()
    # Kaiming init for index 2 of the assembled network.
    apply_init(net[2], nn.init.kaiming_normal_)
    return learner
- # -----------------------------
|