
Remove default loss function from generators

If the loss function is set as a default argument, Python evaluates it once at
function definition time (i.e. when the module is imported), so it gets loaded
even if it is never used, which is not desired.

FeatureLoss is used only for training, but it was also being loaded for
inference, using a little more memory than required because of its VGG16 features.

This does not break anything, because the training and inference call sites
always set the gen_loss argument.
Alexandre Vicenzi committed 5 years ago
commit 6c6e2ee6e4
1 changed file, with 2 additions and 8 deletions

deoldify/generators.py  +2  -8

@@ -19,10 +19,7 @@ def gen_inference_wide(
 
 
 def gen_learner_wide(
-    data: ImageDataBunch,
-    gen_loss=FeatureLoss(),
-    arch=models.resnet101,
-    nf_factor: int = 2,
+    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
 ) -> Learner:
     return unet_learner_wide(
         data,
@@ -96,10 +93,7 @@ def gen_inference_deep(
 
 
 def gen_learner_deep(
-    data: ImageDataBunch,
-    gen_loss=FeatureLoss(),
-    arch=models.resnet34,
-    nf_factor: float = 1.5,
+    data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
 ) -> Learner:
     return unet_learner_deep(
         data,
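
For reference, a minimal sketch of what a training call site looks like after this change. The helper name, the `data_gen` variable, and the `deoldify.loss` import path are assumptions for illustration and are not part of this commit:

# Hypothetical caller sketch; names and import paths below are assumptions.
from fastai.vision import models
from fastai.vision.data import ImageDataBunch

from deoldify.generators import gen_learner_wide
from deoldify.loss import FeatureLoss  # assumed location of FeatureLoss

def build_training_generator(data_gen: ImageDataBunch):
    # Training must now pass the loss explicitly; FeatureLoss (and the
    # VGG16 features it pulls in) is only instantiated here, not at import time.
    return gen_learner_wide(
        data=data_gen, gen_loss=FeatureLoss(), arch=models.resnet101, nf_factor=2
    )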