# tensorboard.py — TensorBoard (tensorboardX) callback integrations for fastai.
import statistics
import time
from abc import ABC, abstractmethod
from queue import Queue, Empty
from threading import Thread
from time import sleep

import torchvision.utils as vutils
from tensorboardX import SummaryWriter

import fastai
from fastai.basic_data import DatasetType, DataBunch
from fastai.basic_train import Learner
from fastai.callbacks import LearnerCallback
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import Image
  16. class AsyncTBWriter(ABC):
  17. def __init__(self):
  18. super().__init__()
  19. self.exit = False
  20. self.queue = Queue()
  21. self.thread = Thread(target=self._queue_processor)
  22. self.thread.start()
  23. def _queue_processor(self):
  24. while not self.exit:
  25. while not self.queue.empty():
  26. request = self.queue.get()
  27. self._write_async(request)
  28. sleep(0.1)
  29. @abstractmethod
  30. def _write_async(self, request):
  31. pass
  32. def __del__(self):
  33. self.exit = True
  34. self.thread.join()
  35. class ModelImageSet():
  36. @staticmethod
  37. def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
  38. image_sets = []
  39. x,y = batch[0],batch[1]
  40. preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
  41. for orig_px, real_px, gen in zip(x,y,preds):
  42. orig = Image(px=orig_px)
  43. real = Image(px=real_px)
  44. image_set = ModelImageSet(orig=orig, real=real, gen=gen)
  45. image_sets.append(image_set)
  46. return image_sets
  47. def __init__(self, orig:Image, real:Image, gen:Image):
  48. self.orig = orig
  49. self.real = real
  50. self.gen = gen
#TODO: There aren't any callbacks using this yet. Not sure if we want this included (not sure if it's useful, honestly)
  52. class ModelGraphTBWriter():
  53. def __init__(self):
  54. return
  55. def write_model_graph_to_tensorboard(self, md:DataBunch, model:nn.Module, tbwriter:SummaryWriter):
  56. x,y = md.one_batch(ds_type=DatasetType.Valid, detach=False, denorm=False)
  57. tbwriter.add_graph(model=model, input_to_model=x)
  58. class HistogramTBRequest():
  59. def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
  60. self.params = [(name, values.clone().detach()) for (name, values) in model.named_parameters()]
  61. self.iteration = iteration
  62. self.tbwriter = tbwriter
  63. self.name = name
#If this isn't done async then this is sloooooow
  65. class HistogramTBWriter(AsyncTBWriter):
  66. def __init__(self):
  67. super().__init__()
  68. # override
  69. def _write_async(self, request:HistogramTBRequest):
  70. try:
  71. params = request.params
  72. iteration = request.iteration
  73. tbwriter = request.tbwriter
  74. name = request.name
  75. for param_name, values in params:
  76. tag = name + '/weights/' + param_name
  77. tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
  78. except Exception as e:
  79. print(("Failed to write model histograms to Tensorboard: {0}").format(e))
  80. def write_tensorboard_histograms(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
  81. request = HistogramTBRequest(model, iteration, tbwriter, name)
  82. self.queue.put(request)
#This is pretty speedy- Don't think we need async writes here
  84. class ModelStatsTBWriter():
  85. def __init__(self):
  86. self.gradients_root = '/gradients/'
  87. def write_tensorboard_stats(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
  88. gradients = [x.grad for x in model.parameters() if x.grad is not None]
  89. gradient_nps = [to_np(x.data) for x in gradients]
  90. if len(gradients) == 0: return
  91. avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
  92. tbwriter.add_scalar(
  93. tag=name + self.gradients_root + 'avg_norm', scalar_value=avg_norm, global_step=iteration)
  94. median_norm = statistics.median(x.data.norm() for x in gradients)
  95. tbwriter.add_scalar(
  96. tag=name + self.gradients_root + 'median_norm', scalar_value=median_norm, global_step=iteration)
  97. max_norm = max(x.data.norm() for x in gradients)
  98. tbwriter.add_scalar(
  99. tag=name + self.gradients_root + 'max_norm', scalar_value=max_norm, global_step=iteration)
  100. min_norm = min(x.data.norm() for x in gradients)
  101. tbwriter.add_scalar(
  102. tag=name + self.gradients_root + 'min_norm', scalar_value=min_norm, global_step=iteration)
  103. num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
  104. tbwriter.add_scalar(
  105. tag=name + self.gradients_root + 'num_zeros', scalar_value=num_zeros, global_step=iteration)
  106. avg_gradient = sum(x.data.mean() for x in gradients)/len(gradients)
  107. tbwriter.add_scalar(
  108. tag=name + self.gradients_root + 'avg_gradient', scalar_value=avg_gradient, global_step=iteration)
  109. median_gradient = statistics.median(x.data.median() for x in gradients)
  110. tbwriter.add_scalar(
  111. tag=name + self.gradients_root + 'median_gradient', scalar_value=median_gradient, global_step=iteration)
  112. max_gradient = max(x.data.max() for x in gradients)
  113. tbwriter.add_scalar(
  114. tag=name + self.gradients_root + 'max_gradient', scalar_value=max_gradient, global_step=iteration)
  115. min_gradient = min(x.data.min() for x in gradients)
  116. tbwriter.add_scalar(
  117. tag=name + self.gradients_root + 'min_gradient', scalar_value=min_gradient, global_step=iteration)
  118. class ImageTBRequest():
  119. def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
  120. self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
  121. self.iteration = iteration
  122. self.tbwriter = tbwriter
  123. self.ds_type = ds_type
#If this isn't done async then this is noticeably slower
  125. class ImageTBWriter(AsyncTBWriter):
  126. def __init__(self):
  127. super().__init__()
  128. # override
  129. def _write_async(self, request:ImageTBRequest):
  130. try:
  131. orig_images = []
  132. gen_images = []
  133. real_images = []
  134. for image_set in request.image_sets:
  135. orig_images.append(image_set.orig.px)
  136. gen_images.append(image_set.gen.px)
  137. real_images.append(image_set.real.px)
  138. prefix = request.ds_type.name
  139. tbwriter = request.tbwriter
  140. iteration = request.iteration
  141. tbwriter.add_image(
  142. tag=prefix + ' orig images', img_tensor=vutils.make_grid(orig_images, normalize=True), global_step=iteration)
  143. tbwriter.add_image(
  144. tag=prefix + ' gen images', img_tensor=vutils.make_grid(gen_images, normalize=True), global_step=iteration)
  145. tbwriter.add_image(
  146. tag=prefix + ' real images', img_tensor=vutils.make_grid(real_images, normalize=True), global_step=iteration)
  147. except Exception as e:
  148. print(("Failed to write images to Tensorboard: {0}").format(e))
  149. def write_images(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
  150. self._write_images_for_dstype(learn=learn, batch=val_batch, iteration=iteration,
  151. tbwriter=tbwriter, ds_type=DatasetType.Valid)
  152. self._write_images_for_dstype(learn=learn, batch=trn_batch, iteration=iteration,
  153. tbwriter=tbwriter, ds_type=DatasetType.Train)
  154. def _write_images_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
  155. request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
  156. self.queue.put(request)
#--------CALLBACKS----------------#
class LearnerTensorboardWriter(LearnerCallback):
    """Callback that periodically writes training loss, metrics, weight
    histograms and gradient stats for `learn` to TensorBoard under
    `base_dir/name`. The `*_iters` arguments control how often (in
    iterations) each kind of data is written."""
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=1000, stats_iters:int=1000):
        super().__init__(learn=learn)
        self.base_dir = base_dir
        self.name = name
        log_dir = base_dir/name
        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
        self.loss_iters = loss_iters    # write train loss every N iterations
        self.hist_iters = hist_iters    # write weight histograms every N iterations
        self.stats_iters = stats_iters  # write gradient stats every N iterations
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        self.data = None                # cached learn.data, used to detect data changes
        self.metrics_root = '/metrics/'

    def _update_batches_if_needed(self):
        # one_batch function is extremely slow. this is an optimization
        # Only refetch cached train/valid batches when learn.data was swapped.
        update_batches = self.data is not self.learn.data
        if update_batches:
            self.data = self.learn.data
            self.trn_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
            self.val_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)

    def _write_model_stats(self, iteration:int):
        "Write gradient statistics for the current model."
        self.stats_writer.write_tensorboard_stats(
            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        "Write the latest training loss as a scalar."
        scalar_value = to_np(last_loss)
        tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_weight_histograms(self, iteration:int):
        "Queue asynchronous histogram writes for all model parameters."
        self.hist_writer.write_tensorboard_histograms(
            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    #TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
        # recorder.names[:start_idx] are assumed to be non-metric columns
        # (see TODO above); metrics proper are matched positionally after that.
        recorder = self.learn.recorder
        for i, name in enumerate(recorder.names[start_idx:]):
            # Stop when there are more names than values supplied.
            if len(last_metrics) < i+1: return
            scalar_value = last_metrics[i]
            tag = self.metrics_root + name
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
        "Callback: write loss/histograms at their configured intervals."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.loss_iters == 0:
            self._write_training_loss(iteration=iteration, last_loss=last_loss)
        if iteration % self.hist_iters == 0:
            self._write_weight_histograms(iteration=iteration)

    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
    def on_backward_end(self, iteration:int, **kwargs):
        "Callback: write gradient stats while gradients are still populated."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.stats_iters == 0:
            self._write_model_stats(iteration=iteration)

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
        "Callback: write the epoch's metrics."
        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
# TODO: We're overriding almost everything here. Seems like a good idea to question that ("is a" vs "has a")
class GANTensorboardWriter(LearnerTensorboardWriter):
    """TensorBoard callback for GAN training: writes separate histograms and
    gradient stats for the generator and critic, plus periodic image grids
    (every `visual_iters` iterations)."""
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=1000,
                 stats_iters:int=1000, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
                         hist_iters=hist_iters, stats_iters=stats_iters)
        self.visual_iters = visual_iters  # write image grids every N iterations
        self.img_gen_vis = ImageTBWriter()
        # True means "stats already written for this interval"; reset to False
        # every stats_iters iterations in on_backward_end.
        self.gen_stats_updated = True
        self.crit_stats_updated = True

    # override
    def _write_weight_histograms(self, iteration:int):
        "Write weight histograms for generator and critic separately."
        trainer = self.learn.gan_trainer
        generator = trainer.generator
        critic = trainer.critic
        self.hist_writer.write_tensorboard_histograms(
            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
        self.hist_writer.write_tensorboard_histograms(
            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')

    # override
    def _write_model_stats(self, iteration:int):
        "Write gradient stats only for the model currently being trained on."
        trainer = self.learn.gan_trainer
        generator = trainer.generator
        critic = trainer.critic
        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
        gen_mode = trainer.gen_mode
        if gen_mode and not self.gen_stats_updated:
            self.stats_writer.write_tensorboard_stats(
                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
            self.gen_stats_updated = True
        if not gen_mode and not self.crit_stats_updated:
            self.stats_writer.write_tensorboard_stats(
                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
            self.crit_stats_updated = True

    # override
    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        "Write the most recently recorded GAN loss (ignores `last_loss`)."
        trainer = self.learn.gan_trainer
        recorder = trainer.recorder
        if len(recorder.losses) > 0:
            scalar_value = to_np((recorder.losses[-1:])[0])
            tag = self.metrics_root + 'train_loss'
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_images(self, iteration:int):
        "Write image grids for the cached batches, forcing generator mode while doing so."
        trainer = self.learn.gan_trainer
        #TODO: Switching gen_mode temporarily seems a bit hacky here. Certainly not a good side-effect. Is there a better way?
        gen_mode = trainer.gen_mode
        try:
            trainer.switch(gen_mode=True)
            self.img_gen_vis.write_images(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                          iteration=iteration, tbwriter=self.tbwriter)
        finally:
            # Always restore the trainer's previous mode, even on failure.
            trainer.switch(gen_mode=gen_mode)

    # override
    def on_batch_end(self, iteration:int, **kwargs):
        "Callback: base-class writes, plus image grids every `visual_iters` iterations."
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write_images(iteration=iteration)

    # override
    def on_backward_end(self, iteration:int, **kwargs):
        "Callback: arm the stats flags on the interval, then write for whichever model is pending."
        if iteration == 0: return
        self._update_batches_if_needed()
        #TODO: This could perhaps be implemented as queues of requests instead but that seemed like overkill.
        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
        if iteration % self.stats_iters == 0:
            self.gen_stats_updated = False
            self.crit_stats_updated = False
        if not (self.gen_stats_updated and self.crit_stats_updated):
            self._write_model_stats(iteration=iteration)
  283. class ImageGenTensorboardWriter(LearnerTensorboardWriter):
  284. def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=1000,
  285. stats_iters: int = 1000, visual_iters: int = 100):
  286. super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
  287. stats_iters=stats_iters)
  288. self.visual_iters = visual_iters
  289. self.img_gen_vis = ImageTBWriter()
  290. def _write_images(self, iteration:int):
  291. self.img_gen_vis.write_images(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
  292. iteration=iteration, tbwriter=self.tbwriter)
  293. # override
  294. def on_batch_end(self, iteration:int, **kwargs):
  295. super().on_batch_end(iteration=iteration, **kwargs)
  296. if iteration == 0: return
  297. if iteration % self.visual_iters == 0:
  298. self._write_images(iteration=iteration)