# callbacks.py (6.4 KB)

  1. from fastai.core import *
  2. from fastai.sgdr import Callback
  3. from fastai.dataset import ModelData, ImageData
  4. from .visualize import ModelStatsVisualizer, ImageGenVisualizer, GANTrainerStatsVisualizer
  5. from .visualize import LearnerStatsVisualizer, ModelGraphVisualizer, ModelHistogramVisualizer
  6. from .training import GenResult, CriticResult, GANTrainer
  7. from .generators import GeneratorModule
  8. from tensorboardX import SummaryWriter
  9. def clear_directory(dir:Path):
  10. for f in dir.glob('*'):
  11. os.remove(f)
  12. class ModelVisualizationHook():
  13. def __init__(self, base_dir: Path, module: nn.Module, name: str, stats_iters: int=10):
  14. self.base_dir = base_dir
  15. self.name = name
  16. log_dir = base_dir/name
  17. clear_directory(log_dir)
  18. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  19. self.hook = module.register_forward_hook(self.forward_hook)
  20. self.stats_iters = stats_iters
  21. self.iter_count = 0
  22. self.model_vis = ModelStatsVisualizer()
  23. def forward_hook(self, module:nn.Module, input, output):
  24. self.iter_count += 1
  25. if self.iter_count % self.stats_iters == 0:
  26. self.model_vis.write_tensorboard_stats(module, iter_count=self.iter_count, tbwriter=self.tbwriter)
  27. def close(self):
  28. self.tbwriter.close()
  29. self.hook.remove()
  30. class GANVisualizationHook():
  31. def __init__(self, base_dir:Path, trainer:GANTrainer, name:str, stats_iters:int=10,
  32. visual_iters:int=200, weight_iters:int=1000, jupyter:bool=False):
  33. super().__init__()
  34. self.base_dir = base_dir
  35. self.name = name
  36. log_dir = base_dir/name
  37. clear_directory(log_dir)
  38. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  39. self.hooks = [trainer.register_train_loop_hook(self.train_loop_hook)]
  40. self.hooks.append(trainer.register_train_begin_hook(self.train_begin_hook))
  41. self.stats_iters = stats_iters
  42. self.visual_iters = visual_iters
  43. self.weight_iters = weight_iters
  44. self.jupyter=jupyter
  45. self.img_gen_vis = ImageGenVisualizer()
  46. self.stats_vis = GANTrainerStatsVisualizer()
  47. self.graph_vis = ModelGraphVisualizer()
  48. self.weight_vis = ModelHistogramVisualizer()
  49. self.trainer = trainer
  50. def train_begin_hook(self):
  51. ds = self.trainer.md.val_ds
  52. self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netD, tbwriter=self.tbwriter)
  53. self.graph_vis.write_model_graph_to_tensorboard(ds=ds, model=self.trainer.netG, tbwriter=self.tbwriter)
  54. def train_loop_hook(self, gresult:GenResult, cresult:CriticResult):
  55. if self.trainer.iters % self.stats_iters == 0:
  56. self.stats_vis.print_stats_in_jupyter(gresult, cresult)
  57. self.stats_vis.write_tensorboard_stats(gresult, cresult, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  58. if self.trainer.iters % self.visual_iters == 0:
  59. model = self.trainer.netG
  60. self.img_gen_vis.output_image_gen_visuals(md=self.trainer.md, model=model, iter_count=self.trainer.iters,
  61. tbwriter=self.tbwriter)
  62. if self.trainer.iters % self.weight_iters == 0:
  63. self.weight_vis.write_tensorboard_histograms(model=self.trainer.netG, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  64. self.weight_vis.write_tensorboard_histograms(model=self.trainer.netD, iter_count=self.trainer.iters, tbwriter=self.tbwriter)
  65. def close(self):
  66. self.tbwriter.close()
  67. for hook in self.hooks:
  68. hook.remove()
  69. class ModelVisualizationCallback(Callback):
  70. def __init__(self, base_dir:Path, model:GeneratorModule, md:ModelData, name:str, stats_iters:int=25,
  71. visual_iters:int=200, weight_iters:int=25, jupyter:bool=False):
  72. super().__init__()
  73. self.base_dir = base_dir
  74. self.name = name
  75. log_dir = base_dir/name
  76. clear_directory(log_dir)
  77. self.tbwriter = SummaryWriter(log_dir=str(log_dir))
  78. self.stats_iters = stats_iters
  79. self.visual_iters = visual_iters
  80. self.weight_iters = weight_iters
  81. self.iter_count = 0
  82. self.model = model
  83. self.md = md
  84. self.jupyter = jupyter
  85. self.learner_vis = LearnerStatsVisualizer()
  86. self.graph_vis = ModelGraphVisualizer()
  87. self.weight_vis = ModelHistogramVisualizer()
  88. self.img_gen_vis = ImageGenVisualizer()
  89. def on_train_begin(self):
  90. self.output_model_graph()
  91. def on_batch_begin(self):
  92. return
  93. def on_phase_begin(self):
  94. return
  95. def on_epoch_end(self, metrics):
  96. self.output_stats(metrics=metrics)
  97. def on_phase_end(self):
  98. return
  99. def on_batch_end(self, metrics):
  100. self.iter_count += 1
  101. if self.iter_count % self.stats_iters == 0:
  102. self.output_stats(metrics=metrics)
  103. if self.iter_count % self.visual_iters == 0:
  104. self.output_visuals()
  105. if self.iter_count % self.weight_iters == 0:
  106. self.output_weights()
  107. def on_train_end(self):
  108. return
  109. def output_model_graph(self):
  110. self.graph_vis.write_model_graph_to_tensorboard(ds=self.md.val_ds, model=self.model, tbwriter=self.tbwriter)
  111. def output_stats(self, metrics):
  112. self.learner_vis.write_tensorboard_stats(metrics=metrics, iter_count=self.iter_count, tbwriter=self.tbwriter)
  113. def output_visuals(self):
  114. self.img_gen_vis.output_image_gen_visuals(md=self.md, model=self.model, iter_count=self.iter_count,
  115. tbwriter=self.tbwriter, jupyter=self.jupyter)
  116. def output_weights(self):
  117. self.weight_vis.write_tensorboard_histograms(model=self.model, iter_count=self.iter_count, tbwriter=self.tbwriter)
  118. def close(self):
  119. self.tbwriter.close()
  120. class ImageGenVisualizationCallback(ModelVisualizationCallback):
  121. def __init__(self, base_dir: Path, model: GeneratorModule, md: ImageData, name: str, stats_iters: int=25, visual_iters: int=200, jupyter:bool=False):
  122. super().__init__(base_dir=base_dir, model=model, md=md, name=name, stats_iters=stats_iters, visual_iters=visual_iters, jupyter=jupyter)
  123. self.img_gen_vis = ImageGenVisualizer()
  124. def output_visuals(self):
  125. super().output_visuals()
  126. self.img_gen_vis.output_image_gen_visuals(md=self.md, model=self.model, iter_count=self.iter_count,
  127. tbwriter=self.tbwriter, jupyter=self.jupyter)