Переглянути джерело

Setting detach to False on hooks for Unet cross connections to allow for loss gradient to take these into account

Jason Antic 6 роки тому
батько
коміт
8a84cfb7bc
1 змінених файлів з 2 додано та 2 видалено
  1. 2 2
      fasterai/unet.py

+ 2 - 2
fasterai/unet.py

@@ -68,7 +68,7 @@ class DynamicUnetDeep(SequentialEx):
         imsize = (256,256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]
@@ -138,7 +138,7 @@ class DynamicUnetWide(SequentialEx):
         imsize = (256,256)
         sfs_szs = model_sizes(encoder, size=imsize)
         sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
-        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
         x = dummy_eval(encoder, imsize).detach()
 
         ni = sfs_szs[-1][1]