callbacks.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from fastai.core import *
  2. from fastai.sgdr import Callback
  3. from fastai.dataset import ModelData, ImageData
  4. from fasterai.visualize import ModelStatsVisualizer, ImageGenVisualizer, GANTrainerStatsVisualizer
  5. from fasterai.visualize import LearnerStatsVisualizer, ModelGraphVisualizer, ModelHistogramVisualizer
  6. from fasterai.training import GenResult, CriticResult, GANTrainer
  7. from tensorboardX import SummaryWriter
  8. def clear_directory(dir:Path):
  9. for f in dir.glob('*'):
  10. os.remove(f)
  11. class ModelVisualizationHook():
  12. def __init__(self, base_dir: Path, module: nn.Module, name: str, stats_iters: int=10):
  13. self.base_dir = base_dir
  14. self.name = name
  15. log_dir = base_dir/name
  16. clear_directory(log_dir)
  17. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  18. self.hook = module.register_forward_hook(self.forward_hook)
  19. self.stats_iters = stats_iters
  20. self.iter_count = 0
  21. self.model_vis = ModelStatsVisualizer()
  22. def forward_hook(self, module:nn.Module, input, output):
  23. self.iter_count += 1
  24. if self.iter_count % self.stats_iters == 0:
  25. self.model_vis.write_tensorboard_stats(module, iter_count=self.iter_count, tbwriter=self.tbwriter)
  26. def close(self):
  27. self.tbwriter.close()
  28. self.hook.remove()
  29. class GANVisualizationHook():
  30. def __init__(self, base_dir:Path, trainer:GANTrainer, name:str, stats_iters:int=10,
  31. visual_iters:int=200, weight_iters:int=1000, jupyter:bool=False):
  32. super().__init__()
  33. self.base_dir = base_dir
  34. self.name = name
  35. log_dir = base_dir/name
  36. clear_directory(log_dir)
  37. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  38. self.hooks = [trainer.register_train_loop_hook(self.train_loop_hook)]
  39. self.hooks.append(trainer.register_train_begin_hook(self.train_begin_hook))
  40. self.stats_iters = stats_iters
  41. self.visual_iters = visual_iters
  42. self.weight_iters = weight_iters
  43. self.jupyter=jupyter
  44. self.img_gen_vis = ImageGenVisualizer()
  45. self.stats_vis = GANTrainerStatsVisualizer()
  46. self.graph_vis = ModelGraphVisualizer()
  47. self.weight_vis = ModelHistogramVisualizer()
  48. self.trainer = trainer
  49. def train_begin_hook(self):
  50. ds = self.trainer.md.val_ds
  51. self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netD, tbwriter=self.tbwriter)
  52. self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netG, tbwriter=self.tbwriter)
  53. def train_loop_hook(self, gresult:GenResult, cresult:CriticResult):
  54. if self.trainer.iters % self.stats_iters == 0:
  55. self.stats_vis.print_stats_in_jupyter(gresult, cresult)
  56. self.stats_vis.write_tensorboard_stats(gresult, cresult, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  57. if self.trainer.iters % self.visual_iters == 0:
  58. model = self.trainer.netG
  59. self.img_gen_vis.output_image_gen_visuals(md=self.trainer.md, model=model, iter_count=self.trainer.iters,
  60. tbwriter=self.tbwriter, jupyter=self.jupyter)
  61. if self.trainer.iters % self.weight_iters == 0:
  62. self.weight_vis.write_tensorboard_histograms(model=self.trainer.netG, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  63. self.weight_vis.write_tensorboard_histograms(model=self.trainer.netD, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  64. def close(self):
  65. self.tbwriter.close()
  66. for hook in self.hooks:
  67. hook.remove()
  68. class ModelVisualizationCallback(Callback):
  69. def __init__(self, base_dir:Path, model:nn.Module, md:ModelData, name:str, stats_iters:int=25,
  70. visual_iters:int=200, weight_iters:int=25, jupyter:bool=False):
  71. super().__init__()
  72. self.base_dir = base_dir
  73. self.name = name
  74. log_dir = base_dir/name
  75. clear_directory(log_dir)
  76. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  77. self.stats_iters = stats_iters
  78. self.visual_iters = visual_iters
  79. self.weight_iters = weight_iters
  80. self.iter_count = 0
  81. self.model = model
  82. self.md = md
  83. self.jupyter = jupyter
  84. self.learner_vis = LearnerStatsVisualizer()
  85. self.graph_vis = ModelGraphVisualizer()
  86. self.weight_vis = ModelHistogramVisualizer()
  87. self.img_gen_vis = ImageGenVisualizer()
  88. def on_train_begin(self):
  89. self.output_model_graph()
  90. def on_batch_begin(self):
  91. return
  92. def on_phase_begin(self):
  93. return
  94. def on_epoch_end(self, metrics):
  95. self.output_stats(metrics=metrics)
  96. def on_phase_end(self):
  97. return
  98. def on_batch_end(self, metrics):
  99. self.iter_count += 1
  100. if self.iter_count % self.stats_iters == 0:
  101. self.output_stats(metrics=metrics)
  102. if self.iter_count % self.visual_iters == 0:
  103. self.output_visuals()
  104. if self.iter_count % self.weight_iters == 0:
  105. self.output_weights()
  106. def on_train_end(self):
  107. return
  108. def output_model_graph(self):
  109. self.graph_vis.write_model_graph_to_tensorboard(ds=self.md.val_ds, model=self.model, tbwriter=self.tbwriter)
  110. def output_stats(self, metrics):
  111. self.learner_vis.write_tensorboard_stats(metrics=metrics, iter_count=self.iter_count, tbwriter=self.tbwriter)
  112. def output_visuals(self):
  113. self.img_gen_vis.output_image_gen_visuals(md=self.md, model=self.model, iter_count=self.iter_count,
  114. tbwriter=self.tbwriter, jupyter=self.jupyter)
  115. def output_weights(self):
  116. self.weight_vis.write_tensorboard_histograms(model=self.model, iter_count=self.iter_count, tbwriter=self.tbwriter)
  117. def close(self):
  118. self.tbwriter.close()
  119. class ImageGenVisualizationCallback(ModelVisualizationCallback):
  120. def __init__(self, base_dir: Path, model: nn.Module, md: ImageData, name: str, stats_iters: int=25, visual_iters: int=200, jupyter:bool=False):
  121. super().__init__(base_dir=base_dir, model=model, md=md, name=name, stats_iters=stats_iters, visual_iters=visual_iters, jupyter=jupyter)
  122. self.img_gen_vis = ImageGenVisualizer()
  123. def output_visuals(self):
  124. super().output_visuals()
  125. self.img_gen_vis.output_image_gen_visuals(md=self.md, model=self.model, iter_count=self.iter_count,
  126. tbwriter=self.tbwriter, jupyter=self.jupyter)