浏览代码

Fixing training loss tensorboard recording for gans

Jason Antic 6 年之前
父节点
当前提交
81c6f16267
共有 2 个文件被更改,包括 9 次插入0 次删除
  1. 1 0
      .gitignore
  2. 8 0
      fasterai/tensorboard.py

+ 1 - 0
.gitignore

@@ -493,3 +493,4 @@ fastai
 ColorizeTrainingNew2.ipynb
 ColorizeTrainingNew3.ipynb
 ColorizeTrainingNew4.ipynb
+.ipynb_checkpoints/ColorizeTraining1-checkpoint.ipynb

+ 8 - 0
fasterai/tensorboard.py

@@ -192,6 +192,14 @@ class GANTensorboardWriter(LearnerTensorboardWriter):
         self.model_vis.write_tensorboard_stats(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
         self.model_vis.write_tensorboard_stats(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
 
+    def _write_training_loss(self, iteration, last_loss):	
+        trainer = self.learn.gan_trainer	
+        recorder = trainer.recorder	
+
+        if len(recorder.losses) > 0:      	
+            trn_loss = to_np((recorder.losses[-1:])[0])	
+            self.tbwriter.add_scalar(self.metrics_root + 'train_loss', trn_loss, iteration)
+
     def _write_images(self, iteration):
         trainer = self.learn.gan_trainer
         recorder = trainer.recorder