from fastai.core import *
from fastai.conv_learner import model_meta, cut_model
from fasterai.modules import ConvBlock, UnetBlock, UpSampleBlock, SaveFeatures
from abc import ABC, abstractmethod


class GeneratorModule(ABC, nn.Module):
    """Abstract ``nn.Module`` base for generators that expose layer groups.

    Subclasses must implement :meth:`get_layer_groups`; :meth:`freeze_to`
    relies on it to toggle trainability group by group.
    """

    def __init__(self):
        super().__init__()

    def set_trainable(self, trainable: bool):
        """Forward to the module-level ``set_trainable`` helper for this module."""
        # Inside a method body the bare name resolves to the module-level
        # helper (from the star import), not to this method.
        set_trainable(self, trainable)

    @abstractmethod
    def get_layer_groups(self, precompute: bool = False) -> list:
        """Return this model's layer groups as a list, in order."""
        pass

    def freeze_to(self, n):
        """Make every layer group untrainable, then re-enable groups ``n`` onward.

        Args:
            n: Index into the list returned by ``get_layer_groups()``; groups
                at positions ``n`` and later are set trainable again.
        """
        groups = self.get_layer_groups()
        for group in groups:
            set_trainable(group, False)
        for group in groups[n:]:
            set_trainable(group, True)

    def get_device(self):
        """Return the device of this module's first parameter.

        Raises:
            StopIteration: If the module has no parameters.
        """
        # The original evaluated this expression but never returned it, so
        # callers always received None.
        return next(self.parameters()).device


class Unet34(GeneratorModule):
    """U-Net style generator whose encoder is a cut ``resnet34``."""

    @staticmethod
    def get_pretrained_resnet_base(layers_cut: int = 0):
        """Build the encoder from ``resnet34``.

        Args:
            layers_cut: Amount subtracted from the cut index that
                ``model_meta`` gives for ``resnet34``.

        Returns:
            A tuple ``(encoder, lr_cut)`` where ``encoder`` is an
            ``nn.Sequential`` of the layers returned by ``cut_model`` and
            ``lr_cut`` is the second value stored in ``model_meta``.
        """
        f = resnet34
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        # f(True): presumably the ``pretrained`` flag -- confirm against the
        # resnet34 factory in use.
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def __init__(self, nf_factor: int = 1, scale: int = 1):
        """Create the encoder, the feature hooks and the decoder blocks.

        Args:
            nf_factor: Multiplier applied to the decoder channel counts.
            scale: Must be a power of two; multiplies the factor passed to
                the final ``UpSampleBlock`` (``2 * scale``).
        """
        super().__init__()
        # log2 is exact for powers of two; log(scale, 2) is not always
        # (e.g. 2**29 gives 29.000000000000004).
        assert math.log2(scale).is_integer(), "scale must be a power of two"
        leakyReLu = False
        self_attention = True
        bn = True
        sn = True
        self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
        self.relu = nn.ReLU()
        # Hooks on four encoder stages; released again in close().
        self.sfs = [SaveFeatures(self.rn[i]) for i in [2, 4, 5, 6]]

        self.up1 = UnetBlock(512, 256, 512 * nf_factor,
                             sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up2 = UnetBlock(512 * nf_factor, 128, 512 * nf_factor,
                             sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up3 = UnetBlock(512 * nf_factor, 64, 512 * nf_factor,
                             sn=sn, self_attention=self_attention,
                             leakyReLu=leakyReLu, bn=bn)
        self.up4 = UnetBlock(512 * nf_factor, 64, 256 * nf_factor,
                             sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up5 = UpSampleBlock(256 * nf_factor, 32 * nf_factor, 2 * scale,
                                 sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.out = nn.Sequential(
            ConvBlock(32 * nf_factor, 3, ks=3, actn=False, bn=False, sn=sn),
            nn.Tanh(),
        )

    # Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x, target):
        """Return ``x``, grown if needed to twice ``target``'s last two dims."""
        h = x.shape[2]
        w = x.shape[3]

        target_h = target.shape[2] * 2
        target_w = target.shape[3] * 2

        # NOTE(review): the received file is cut off right after `if h`.
        # Everything from this condition to the end of the method is a
        # reconstruction (zero-pad bottom/right up to the doubled target
        # size) -- confirm against the upstream source.
        if h < target_h or w < target_w:
            pad_h = max(target_h - h, 0)
            pad_w = max(target_w - w, 0)
            return nn.functional.pad(x, (0, pad_w, 0, pad_h))

        return x

    # NOTE(review): `forward` is missing from the received file (lost in the
    # same damaged span as the tail of `_pad`). This version is inferred from
    # the constructor: the skip-channel arguments 256/128/64/64 of up1..up4
    # line up with the hooks on self.rn[6]/[5]/[4]/[2]. It also assumes
    # SaveFeatures exposes the hooked output as `.features` and that
    # UnetBlock is called as (upsampled_input, skip). Confirm all of this
    # against the upstream source before relying on it.
    def forward(self, x_in):
        """Run the encoder, then the decoder with the hooked skip features."""
        x = self.rn(x_in)
        x = self.relu(x)
        x = self.up1(x, self._pad(self.sfs[3].features, x))
        x = self.up2(x, self._pad(self.sfs[2].features, x))
        x = self.up3(x, self._pad(self.sfs[1].features, x))
        x = self.up4(x, self._pad(self.sfs[0].features, x))
        x = self.up5(x)
        x = self.out(x)
        return x

    # NOTE(review): this method's header was also lost; it is restored from
    # the abstract signature declared on GeneratorModule, which matches the
    # surviving `[]:` suffix. The body below is intact in the received file.
    def get_layer_groups(self, precompute: bool = False) -> list:
        """Return the encoder split at ``lr_cut`` plus one decoder group.

        ``precompute`` is accepted to match the abstract signature and is not
        used here.
        """
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        # children(self)[1:] skips the first child (the encoder, already
        # covered by the split above).
        return lgs + [children(self)[1:]]

    def close(self):
        """Call ``remove()`` on every ``SaveFeatures`` hook in ``self.sfs``."""
        for sf in self.sfs:
            sf.remove()