Parcourir la source

More experiments and fixes

Jason Antic il y a 6 ans
Parent
commit
f4ed956df4
3 fichiers modifiés avec 89 ajouts et 6 suppressions
  1. 19 1
      fasterai/generators.py
  2. 4 4
      fasterai/tensorboard.py
  3. 66 1
      fasterai/unet.py

+ 19 - 1
fasterai/generators.py

@@ -1,6 +1,6 @@
 from fastai.vision import *
 from fastai.vision.learner import cnn_config
-from fasterai.unet import DynamicUnet2, DynamicUnet3, DynamicUnet4
+from fasterai.unet import DynamicUnet2, DynamicUnet3, DynamicUnet4, DynamicUnet5
 from .loss import FeatureLoss
 
 def colorize_gen_learner(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):
@@ -60,3 +60,21 @@ def unet_learner4(data:DataBunch, arch:Callable, pretrained:bool=True, blur_fina
     apply_init(model[2], nn.init.kaiming_normal_)
     return learn
 
+
+#No batch norm in ESRGAN paper, custom nf width
+def unet_learner5(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
+                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
+                 bottle:bool=True, **kwargs:Any)->Learner:
+    "Build Unet learner from `data` and `arch`, using a `DynamicUnet5` (custom `nf` width) as the model."
+    # Backbone metadata: where to cut the head and how to split layer groups.
+    meta = cnn_config(arch)
+    # Encoder = the backbone with its classifier head removed.
+    body = create_body(arch, pretrained)
+    model = to_device(DynamicUnet5(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
+          bottle=bottle), data.device)
+    learn = Learner(data, model, **kwargs)
+    # Layer-group split enables discriminative learning rates; caller may override via split_on.
+    learn.split(ifnone(split_on,meta['split']))
+    # With a pretrained encoder, start by training only the decoder.
+    if pretrained: learn.freeze()
+    # NOTE(review): model[2] is the ReLU right after the encoder's batchnorm (see
+    # DynamicUnet5's layers list), which has no parameters -- confirm the intended init target.
+    apply_init(model[2], nn.init.kaiming_normal_)
+    return learn
+

+ 4 - 4
fasterai/tensorboard.py

@@ -33,18 +33,18 @@ class TBWriteRequest(ABC):
 class AsyncTBWriter():
     def __init__(self):
         super().__init__()
-        self.stoprequest = Event()
+        self.stop_request = Event()
         self.queue = Queue()
         self.thread = Thread(target=self._queue_processor, daemon=True)
         self.thread.start()
 
     def request_write(self, request: TBWriteRequest):
-        if self.stoprequest.isSet():
+        if self.stop_request.isSet():
             raise Exception('Close was already called!  Cannot perform this operation.')
         self.queue.put(request)
 
     def _queue_processor(self):
-        while not self.stoprequest.isSet():
+        while not self.stop_request.isSet():
             while not self.queue.empty():
                 request = self.queue.get()
                 request.write()
@@ -53,7 +53,7 @@ class AsyncTBWriter():
 #Provided to stop the thread explicitly or via context management (with statement), but the thread should end on its own
 # upon program exit, due to being a daemon.  So using this is probably unnecessary.
     def close(self):
-        self.stoprequest.set()
+        self.stop_request.set()
         self.thread.join()
 
     def __enter__(self):

+ 66 - 1
fasterai/unet.py

@@ -184,4 +184,69 @@ class DynamicUnet4(SequentialEx):
         super().__init__(*layers)
 
     def __del__(self):
-        if hasattr(self, "sfs"): self.sfs.remove()
+        if hasattr(self, "sfs"): self.sfs.remove()
+
+class UnetBlock5(nn.Module):
+    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
+    def __init__(self, up_in_c:int, x_in_c:int, out_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
+                 self_attention:bool=False,  **kwargs):
+        # up_in_c: channels coming up from the layer below; x_in_c: channels of the
+        # skip (cross) connection; out_c: explicit output width -- the caller picks
+        # the filter count rather than deriving it.
+        # NOTE(review): `final_div` is accepted but never read in this block --
+        # presumably kept for signature compatibility; confirm before relying on it.
+        super().__init__()
+        # Hook holding the encoder activation used as the skip connection.
+        self.hook = hook
+        # Sub-pixel (PixelShuffle) upsampling, halving the channel count.
+        self.shuf = PixelShuffle_ICNR2(up_in_c, up_in_c//2, blur=blur, leaky=leaky, **kwargs)
+        self.bn = batchnorm_2d(x_in_c)
+        ni = up_in_c//2 + x_in_c
+        nf = out_c
+        self.conv = conv_layer2(ni, nf, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.relu = relu(leaky=leaky)
+
+    def forward(self, up_in:Tensor) -> Tensor:
+        # Skip activation captured by the forward hook during the encoder pass.
+        s = self.hook.stored
+        up_out = self.shuf(up_in)
+        ssh = s.shape[-2:]
+        # Resize if odd input sizes left the two feature maps spatially misaligned.
+        if ssh != up_out.shape[-2:]:
+            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
+        # Concatenate upsampled features with the (batch-normed) skip features.
+        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
+        return self.conv(cat_x)
+
+#custom filter widths
+class DynamicUnet5(SequentialEx):
+    "Create a U-Net from a given architecture, with a fixed decoder width `nf` instead of derived widths."
+    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
+                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=True,
+                 norm_type:Optional[NormType]=NormType.Batch, nf:int=256, **kwargs):
+        # Spectral-norm runs get an extra batchnorm inside conv_layer2.
+        extra_bn =  norm_type == NormType.Spectral
+        # Probe size used only to trace feature-map shapes through the encoder.
+        imsize = (256,256)
+        sfs_szs = model_sizes(encoder, size=imsize)
+        # Encoder layers where the feature-map size changes, deepest first: the skip points.
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        # Hooks storing those activations for the UnetBlock5 cross connections.
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        x = dummy_eval(encoder, imsize).detach()
+
+        ni = sfs_szs[-1][1]
+        # Bottleneck: expand then contract channels at the deepest point.
+        middle_conv = nn.Sequential(conv_layer2(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    conv_layer2(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        x = middle_conv(x)
+        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
+
+        for i,idx in enumerate(sfs_idxs):
+            not_final = i!=len(sfs_idxs)-1
+            # First block takes the bottleneck's channels; later blocks all take `nf` (the custom fixed width).
+            up_in_c = int(x.shape[1]) if i == 0 else nf
+            x_in_c = int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            # Self-attention only on the third-from-last decoder block.
+            sa = self_attention and (i==len(sfs_idxs)-3)
+            # NOTE(review): `do_blur` is computed but `blur=blur` is passed below, so
+            # `blur_final` has no effect on the blocks -- likely meant `blur=do_blur`.
+            unet_block = UnetBlock5(up_in_c, x_in_c, nf, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
+            layers.append(unet_block)
+            # Run the dummy tensor through so the next iteration sees real shapes.
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        # One more upsample if the decoder has not reached the input resolution yet.
+        # NOTE(review): uses PixelShuffle_ICNR here while the blocks use
+        # PixelShuffle_ICNR2 -- confirm the mismatch is intentional.
+        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            # Final cross connection densely merges the original input back in.
+            layers.append(MergeLayer(dense=True))
+            ni += in_channels(encoder)
+            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
+        # 1x1 conv maps to the requested number of output channels, no activation.
+        layers += [conv_layer2(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        # Optional sigmoid scaled to y_range for bounded outputs.
+        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
+    def __del__(self):
+        # Release the forward hooks when the model is garbage-collected.
+        if hasattr(self, "sfs"): self.sfs.remove()