
Set detach to False on the hooks for the U-Net cross connections so that the loss gradient takes them into account

Jason Antic, 6 years ago
parent commit 8a84cfb7bc
1 file changed, 2 additions and 2 deletions
  1. fasterai/unet.py (+2 −2)

fasterai/unet.py  (+2 −2)

@@ -68,7 +68,7 @@ class DynamicUnetDeep(SequentialEx):
         imsize = (256,256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
@@ -138,7 +138,7 @@ class DynamicUnetWide(SequentialEx):
         imsize = (256,256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
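
The change passes detach=False to hook_outputs so the encoder activations captured for the U-Net cross connections stay in the autograd graph instead of being detached when stored. Below is a minimal, pure-PyTorch sketch (not the fasterai or fastai implementation; the encoder, hook, and loss here are made up for illustration) showing why a detached hooked activation blocks gradients from reaching the encoder, while an undetached one lets them flow:

    import torch
    import torch.nn as nn

    # Toy stand-in for the encoder whose intermediate outputs get hooked.
    encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())

    stored = {}

    def make_hook(name, detach):
        def hook(module, inp, out):
            # With detach=True the saved activation is cut out of the
            # autograd graph, so a loss computed from it cannot reach
            # the encoder weights.
            stored[name] = out.detach() if detach else out
        return hook

    for detach in (True, False):
        stored.clear()
        handle = encoder[1].register_forward_hook(make_hook('relu1', detach))
        _ = encoder(torch.randn(1, 3, 32, 32))
        handle.remove()

        # Pretend the hooked activation feeds the decoder through a skip
        # connection and contributes to the loss.
        loss = stored['relu1'].pow(2).mean()
        if loss.requires_grad:
            loss.backward()
        grad = encoder[0].weight.grad
        print(f"detach={detach}: activation requires_grad="
              f"{stored['relu1'].requires_grad}, encoder grad populated: "
              f"{grad is not None}")
    # detach=True  -> activation requires_grad=False, no encoder gradients
    # detach=False -> gradients flow back into the encoder weights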