@@ -1,4 +1,4 @@
-import fastai
+"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
from fastai.basic_train import Learner
from fastai.basic_data import DatasetType, DataBunch
from fastai.vision import Image
@@ -6,7 +6,6 @@ from fastai.callbacks import LearnerCallback
from fastai.core import *
from fastai.torch_core import *
from threading import Thread, Event
-import time
from time import sleep
from queue import Queue
import statistics
@@ -15,6 +14,189 @@ from abc import ABC, abstractmethod
from tensorboardX import SummaryWriter

+__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
+
+
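+# LearnerTensorboardWriter writes the training loss every `loss_iters` batches,
+# weight histograms every `hist_iters` batches, model weight/gradient stats every
+# `stats_iters` batches, and the recorder's metrics at the end of each epoch.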
+class LearnerTensorboardWriter(LearnerCallback):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
+        super().__init__(learn=learn)
+        self.base_dir = base_dir
+        self.name = name
+        log_dir = base_dir/name
+        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
+        self.loss_iters = loss_iters
+        self.hist_iters = hist_iters
+        self.stats_iters = stats_iters
+        self.hist_writer = HistogramTBWriter()
+        self.stats_writer = ModelStatsTBWriter()
+        self.data = None
+        self.metrics_root = '/metrics/'
+        self._update_batches_if_needed()
+
+    def _update_batches_if_needed(self):
+        # The one_batch function is extremely slow with large datasets, so this check is an optimization.
+        # Note that we also want to always show the same batches so changes are visible in TensorBoard.
+        update_batches = self.data is not self.learn.data
+
+        if update_batches:
+            self.data = self.learn.data
+            self.trn_batch = self.learn.data.one_batch(
+                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
+            self.val_batch = self.learn.data.one_batch(
+                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
+
+    def _write_model_stats(self, iteration:int):
+        self.stats_writer.write(
+            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
+
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
+        scalar_value = to_np(last_loss)
+        tag = self.metrics_root + 'train_loss'
+        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def _write_weight_histograms(self, iteration:int):
+        self.hist_writer.write(
+            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
+
+    #TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
+    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
+        recorder = self.learn.recorder
+
+        for i, name in enumerate(recorder.names[start_idx:]):
+            if len(last_metrics) < i+1: return
+            scalar_value = last_metrics[i]
+            tag = self.metrics_root + name
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        if iteration % self.loss_iters == 0:
+            self._write_training_loss(iteration=iteration, last_loss=last_loss)
+
+        if iteration % self.hist_iters == 0:
+            self._write_weight_histograms(iteration=iteration)
+
+    # Stats that need gradient info are written here, because gradients get zeroed out afterwards in the training loop
+    def on_backward_end(self, iteration:int, **kwargs):
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        if iteration % self.stats_iters == 0:
+            self._write_model_stats(iteration=iteration)
+
+    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
+        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
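+
+# A minimal usage sketch (the Learner construction and paths here are illustrative,
+# not part of this module):
+#
+#   learn = cnn_learner(data, models.resnet34)
+#   learn.callbacks.append(LearnerTensorboardWriter(
+#       learn=learn, base_dir=Path('data/tensorboard'), name='run1'))
+#   learn.fit(3)
+#
+# Scalars land under the '/metrics/' tag root, e.g. '/metrics/train_loss'; point
+# TensorBoard at the same directory to view them:  tensorboard --logdir data/tensorboard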
+
+# TODO: We're overriding almost everything here. It seems like a good idea to question that ("is a" vs "has a").
+class GANTensorboardWriter(LearnerTensorboardWriter):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
+                 stats_iters:int=100, visual_iters:int=100):
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
+                         hist_iters=hist_iters, stats_iters=stats_iters)
+        self.visual_iters = visual_iters
+        self.img_gen_vis = ImageTBWriter()
+        self.gen_stats_updated = True
+        self.crit_stats_updated = True
+
+    # override
+    def _write_weight_histograms(self, iteration:int):
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+        self.hist_writer.write(
+            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
+        self.hist_writer.write(
+            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
+
+    # override
+    def _write_model_stats(self, iteration:int):
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+
+        # Don't write stats for the model that is not currently being trained, since its gradients are zeroed out
+        gen_mode = trainer.gen_mode
+
+        if gen_mode and not self.gen_stats_updated:
+            self.stats_writer.write(
+                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
+            self.gen_stats_updated = True
+
+        if not gen_mode and not self.crit_stats_updated:
+            self.stats_writer.write(
+                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
+            self.crit_stats_updated = True
+
+    # override
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
+        trainer = self.learn.gan_trainer
+        recorder = trainer.recorder
+
+        if len(recorder.losses) > 0:
+            scalar_value = to_np(recorder.losses[-1])
+            tag = self.metrics_root + 'train_loss'
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def _write(self, iteration:int):
+        trainer = self.learn.gan_trainer
+        #TODO: Temporarily switching gen_mode seems a bit hacky here, and it's certainly not a good side effect. Is there a better way?
+        gen_mode = trainer.gen_mode
+
+        try:
+            trainer.switch(gen_mode=True)
+            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
+                                   iteration=iteration, tbwriter=self.tbwriter)
+        finally:
+            trainer.switch(gen_mode=gen_mode)
+
+    # override
+    def on_batch_end(self, iteration:int, **kwargs):
+        super().on_batch_end(iteration=iteration, **kwargs)
+        if iteration == 0: return
+        if iteration % self.visual_iters == 0:
+            self._write(iteration=iteration)
+
+    # override
+    def on_backward_end(self, iteration:int, **kwargs):
+        if iteration == 0: return
+        self._update_batches_if_needed()
+
+        #TODO: This could perhaps be implemented as queues of requests instead, but that seemed like overkill.
+        # I'm not the biggest fan of maintaining these boolean flags either... review, please.
+        if iteration % self.stats_iters == 0:
+            self.gen_stats_updated = False
+            self.crit_stats_updated = False
+
+        if not (self.gen_stats_updated and self.crit_stats_updated):
+            self._write_model_stats(iteration=iteration)
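+
+# A minimal usage sketch (assumes a GAN Learner that exposes the `gan_trainer`
+# attribute this callback relies on, e.g. one built with fastai's GANLearner;
+# the data/generator/critic names are illustrative):
+#
+#   learn = GANLearner.wgan(data, generator, critic)
+#   learn.callbacks.append(GANTensorboardWriter(
+#       learn=learn, base_dir=Path('data/tensorboard'), name='gan_run'))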
+
+
+class ImageGenTensorboardWriter(LearnerTensorboardWriter):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
+                 stats_iters:int=100, visual_iters:int=100):
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
+                         stats_iters=stats_iters)
+        self.visual_iters = visual_iters
+        self.img_gen_vis = ImageTBWriter()
+
+    def _write(self, iteration:int):
+        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
+                               iteration=iteration, tbwriter=self.tbwriter)
+
+    # override
+    def on_batch_end(self, iteration:int, **kwargs):
+        super().on_batch_end(iteration=iteration, **kwargs)
+        if iteration == 0: return
+
+        if iteration % self.visual_iters == 0:
+            self._write(iteration=iteration)
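+
+# A minimal usage sketch (illustrative; any image-generating Learner, such as one
+# built with unet_learner, works the same way):
+#
+#   learn.callbacks.append(ImageGenTensorboardWriter(
+#       learn=learn, base_dir=Path('data/tensorboard'), name='imagegen_run'))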
+
+
+#------PRIVATE-----------
+
class TBWriteRequest(ABC):
    def __init__(self, tbwriter: SummaryWriter, iteration:int):
        super().__init__()
@@ -218,180 +400,3 @@ class ImageTBWriter():
-#--------CALLBACKS----------------#
-class LearnerTensorboardWriter(LearnerCallback):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
-        super().__init__(learn=learn)
-        self.base_dir = base_dir
-        self.name = name
-        log_dir = base_dir/name
-        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.loss_iters = loss_iters
-        self.hist_iters = hist_iters
-        self.stats_iters = stats_iters
-        self.hist_writer = HistogramTBWriter()
-        self.stats_writer = ModelStatsTBWriter()
-        self.data = None
-        self.metrics_root = '/metrics/'
-        self._update_batches_if_needed()
-
-    def _update_batches_if_needed(self):
-        # one_batch function is extremely slow with large datasets. This is an optimization.
-        # Note that also we want to always show the same batches so we can see changes
-        # in tensorboard
-        update_batches = self.data is not self.learn.data
-
-        if update_batches:
-            self.data = self.learn.data
-            self.trn_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
-            self.val_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
-
-    def _write_model_stats(self, iteration:int):
-        self.stats_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        scalar_value = to_np(last_loss)
-        tag = self.metrics_root + 'train_loss'
-        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write_weight_histograms(self, iteration:int):
-        self.hist_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    #TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
-    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
-        recorder = self.learn.recorder
-
-        for i, name in enumerate(recorder.names[start_idx:]):
-            if len(last_metrics) < i+1: return
-            scalar_value = last_metrics[i]
-            tag = self.metrics_root + name
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.loss_iters == 0:
-            self._write_training_loss(iteration=iteration, last_loss=last_loss)
-
-        if iteration % self.hist_iters == 0:
-            self._write_weight_histograms(iteration=iteration)
-
-    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.stats_iters == 0:
-            self._write_model_stats(iteration=iteration)
-
-    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
-        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
-
-# TODO: We're overriding almost everything here. Seems like a good idea to question that ("is a" vs "has a")
-class GANTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters:int=100, visual_iters:int=100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
-                         hist_iters=hist_iters, stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-        self.gen_stats_updated = True
-        self.crit_stats_updated = True
-
-    # override
-    def _write_weight_histograms(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-        self.hist_writer.write(
-            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
-        self.hist_writer.write(
-            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
-
-    # override
-    def _write_model_stats(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-
-        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
-        gen_mode = trainer.gen_mode
-
-        if gen_mode and not self.gen_stats_updated:
-            self.stats_writer.write(
-                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
-            self.gen_stats_updated = True
-
-        if not gen_mode and not self.crit_stats_updated:
-            self.stats_writer.write(
-                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
-            self.crit_stats_updated = True
-
-    # override
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        trainer = self.learn.gan_trainer
-        recorder = trainer.recorder
-
-        if len(recorder.losses) > 0:
-            scalar_value = to_np((recorder.losses[-1:])[0])
-            tag = self.metrics_root + 'train_loss'
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        #TODO: Switching gen_mode temporarily seems a bit hacky here. Certainly not a good side-effect. Is there a better way?
-        gen_mode = trainer.gen_mode
-
-        try:
-            trainer.switch(gen_mode=True)
-            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                   iteration=iteration, tbwriter=self.tbwriter)
-        finally:
-            trainer.switch(gen_mode=gen_mode)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)
-
-    # override
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        #TODO: This could perhaps be implemented as queues of requests instead but that seemed like overkill.
-        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
-        if iteration % self.stats_iters == 0:
-            self.gen_stats_updated = False
-            self.crit_stats_updated = False
-
-        if not (self.gen_stats_updated and self.crit_stats_updated):
-            self._write_model_stats(iteration=iteration)
-
-
-class ImageGenTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters: int = 100, visual_iters: int = 100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
-                         stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-
-    def _write(self, iteration:int):
-        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                               iteration=iteration, tbwriter=self.tbwriter)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)