tensorboard.py
  1. "Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
  2. from fastai.basic_train import Learner
  3. from fastai.basic_data import DatasetType, DataBunch
  4. from fastai.vision import Image
  5. from fastai.callbacks import LearnerCallback
  6. from fastai.core import *
  7. from fastai.torch_core import *
  8. from threading import Thread, Event
  9. from time import sleep
  10. from queue import Queue
  11. import statistics
  12. import torchvision.utils as vutils
  13. from abc import ABC, abstractmethod
  14. from tensorboardX import SummaryWriter
  15. __all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
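
# Example usage, as a minimal sketch: 'learn', the log path and the run name below are
# illustrative, not part of this module. Attach the callback to any fastai Learner,
# train, then point TensorBoard at the base directory.
#
#   learn = ...  # any fastai Learner
#   tb = LearnerTensorboardWriter(learn=learn, base_dir=Path('runs'), name='exp-1')
#   learn.callbacks.append(tb)
#   learn.fit_one_cycle(4)
#
# Then, from a shell: tensorboard --logdir=runs
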
class LearnerTensorboardWriter(LearnerCallback):
    "Broadly useful callback for Learners that writes losses/metrics, weight histograms and gradient stats to Tensorboard."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
        super().__init__(learn=learn)
        self.base_dir = base_dir
        self.name = name
        log_dir = base_dir/name
        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
        self.loss_iters = loss_iters
        self.hist_iters = hist_iters
        self.stats_iters = stats_iters
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        self.data = None
        self.metrics_root = '/metrics/'
        self._update_batches_if_needed()

    def _update_batches_if_needed(self):
        # The one_batch function is extremely slow with large datasets, so we cache the batches
        # here as an optimization. We also always want to show the same batches, so that changes
        # are visible from step to step in Tensorboard.
        update_batches = self.data is not self.learn.data
        if update_batches:
            self.data = self.learn.data
            self.trn_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
            self.val_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)

    def _write_model_stats(self, iteration:int):
        self.stats_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        scalar_value = to_np(last_loss)
        tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_weight_histograms(self, iteration:int):
        self.hist_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    # TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
        recorder = self.learn.recorder
        for i, name in enumerate(recorder.names[start_idx:]):
            if len(last_metrics) < i+1: return
            scalar_value = last_metrics[i]
            tag = self.metrics_root + name
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.loss_iters == 0:
            self._write_training_loss(iteration=iteration, last_loss=last_loss)
        if iteration % self.hist_iters == 0:
            self._write_weight_histograms(iteration=iteration)

    # Work that requires gradient info is done here, because gradients get zeroed out
    # afterwards in the training loop.
    def on_backward_end(self, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.stats_iters == 0:
            self._write_model_stats(iteration=iteration)

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
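
# For reference, with the default writer names this callback produces tags grouped in
# TensorBoard roughly as follows (metric names are taken from learn.recorder.names):
#
#   /metrics/train_loss           scalar, every loss_iters iterations
#   /metrics/<metric name>        scalars, written at each epoch end
#   model/weights/<param name>    histograms, every hist_iters iterations
#   model_stats/gradients/<stat>  scalars, every stats_iters iterations
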
# TODO: We're overriding almost everything here. It seems like a good idea to question that ("is a" vs "has a").
class GANTensorboardWriter(LearnerTensorboardWriter):
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                 stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
                         hist_iters=hist_iters, stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()
        self.gen_stats_updated = True
        self.crit_stats_updated = True

    # override
    def _write_weight_histograms(self, iteration:int):
        trainer = self.learn.gan_trainer
        generator, critic = trainer.generator, trainer.critic
        self.hist_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
        self.hist_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')

    # override
    def _write_model_stats(self, iteration:int):
        trainer = self.learn.gan_trainer
        generator, critic = trainer.generator, trainer.critic
        # We don't want to write stats for a model that is not being iterated on, since its gradients are zeroed out.
        gen_mode = trainer.gen_mode
        if gen_mode and not self.gen_stats_updated:
            self.stats_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
            self.gen_stats_updated = True
        if not gen_mode and not self.crit_stats_updated:
            self.stats_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
            self.crit_stats_updated = True

    # override
    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        trainer = self.learn.gan_trainer
        recorder = trainer.recorder
        if len(recorder.losses) > 0:
            scalar_value = to_np(recorder.losses[-1])
            tag = self.metrics_root + 'train_loss'
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write(self, iteration:int):
        trainer = self.learn.gan_trainer
        # TODO: Temporarily switching gen_mode seems a bit hacky; it's certainly not a good side effect. Is there a better way?
        gen_mode = trainer.gen_mode
        try:
            trainer.switch(gen_mode=True)
            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                   iteration=iteration, tbwriter=self.tbwriter)
        finally:
            trainer.switch(gen_mode=gen_mode)

    # override
    def on_batch_end(self, iteration:int, **kwargs):
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write(iteration=iteration)

    # override
    def on_backward_end(self, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        # TODO: This could perhaps be implemented as a queue of requests instead, but that seemed like overkill.
        # Maintaining these boolean flags isn't ideal either... review, please.
        if iteration % self.stats_iters == 0:
            self.gen_stats_updated = False
            self.crit_stats_updated = False
        if not (self.gen_stats_updated and self.crit_stats_updated):
            self._write_model_stats(iteration=iteration)
class ImageGenTensorboardWriter(LearnerTensorboardWriter):
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                 stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
                         stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()

    def _write(self, iteration:int):
        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                               iteration=iteration, tbwriter=self.tbwriter)

    # override
    def on_batch_end(self, iteration:int, **kwargs):
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write(iteration=iteration)
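
# Example usage for the image-generation variants: a sketch, where 'gan_learn' and the
# paths are illustrative. GANTensorboardWriter expects a Learner with a gan_trainer
# attached (e.g. one built with fastai's GANLearner), since it reads the generator and
# critic off learn.gan_trainer:
#
#   gan_learn.callbacks.append(GANTensorboardWriter(
#       learn=gan_learn, base_dir=Path('runs'), name='gan-exp', visual_iters=200))
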
#------PRIVATE-----------

class TBWriteRequest(ABC):
    "A request object for Tensorboard writes. Useful for queuing up and executing asynchronous writes."
    def __init__(self, tbwriter: SummaryWriter, iteration:int):
        super().__init__()
        self.tbwriter = tbwriter
        self.iteration = iteration

    @abstractmethod
    def write(self):
        pass

# SummaryWriter writes tend to block quite a bit, so doing them asynchronously greatly boosts performance.
# Not all Tensorboard writes go through this queue; just the ones that take a long time. Note that
# SummaryWriter itself ultimately uses a threadsafe producer/consumer design to write to Tensorboard,
# so writes done outside of this async loop should be fine.
class AsyncTBWriter():
    def __init__(self):
        super().__init__()
        self.stop_request = Event()
        self.queue = Queue()
        self.thread = Thread(target=self._queue_processor, daemon=True)
        self.thread.start()

    def request_write(self, request: TBWriteRequest):
        if self.stop_request.is_set():
            raise Exception('Close was already called! Cannot perform this operation.')
        self.queue.put(request)

    def _queue_processor(self):
        while not self.stop_request.is_set():
            while not self.queue.empty():
                request = self.queue.get()
                request.write()
            sleep(0.2)

    # Provided so the thread can be stopped explicitly or by context management (a with statement),
    # but the thread should end on its own upon program exit because it's a daemon, so using this
    # is probably unnecessary.
    def close(self):
        self.stop_request.set()
        self.thread.join()

    def __enter__(self):
        # Nothing to do: the thread is already started in __init__. We could start it here instead,
        # to enforce use of a context manager, but that seems unwieldy and unnecessary for actual usage.
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

asyncTBWriter = AsyncTBWriter()
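
# The async machinery above is reusable: any slow write can be wrapped in a TBWriteRequest
# subclass and handed to asyncTBWriter. A minimal sketch follows; ScalarTBRequest and
# some_writer are hypothetical, not part of this module.
#
#   class ScalarTBRequest(TBWriteRequest):
#       def __init__(self, tbwriter:SummaryWriter, iteration:int, tag:str, value:float):
#           super().__init__(tbwriter=tbwriter, iteration=iteration)
#           self.tag, self.value = tag, value
#
#       # override
#       def write(self):
#           self.tbwriter.add_scalar(tag=self.tag, scalar_value=self.value, global_step=self.iteration)
#
#   asyncTBWriter.request_write(
#       ScalarTBRequest(tbwriter=some_writer, iteration=100, tag='custom/x', value=0.5))
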
class ModelImageSet():
    "Holds the original input, the real (target) image and the generated image for one example."
    @staticmethod
    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->list:
        image_sets = []
        x,y = batch[0],batch[1]
        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
        for orig_px, real_px, gen in zip(x, y, preds):
            orig = Image(px=orig_px)
            real = Image(px=real_px)
            image_set = ModelImageSet(orig=orig, real=real, gen=gen)
            image_sets.append(image_set)
        return image_sets

    def __init__(self, orig:Image, real:Image, gen:Image):
        self.orig = orig
        self.real = real
        self.gen = gen
class HistogramTBRequest(TBWriteRequest):
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # Clone parameters to the CPU up front so the write can safely happen on another thread.
        self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]
        self.name = name

    # override
    def write(self):
        try:
            for param_name, values in self.params:
                tag = self.name + '/weights/' + param_name
                self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)
        except Exception as e:
            print("Failed to write model histograms to Tensorboard: {0}".format(e))

# If this isn't done asynchronously, it is sloooooow.
class HistogramTBWriter():
    def __init__(self):
        super().__init__()

    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
        request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)
class ModelStatsTBRequest(TBWriteRequest):
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # Clone gradients to the CPU up front so the write can safely happen on another thread.
        self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]
        self.name = name
        self.gradients_root = '/gradients/'

    def _add_gradient_scalar(self, stat_name:str, scalar_value):
        tag = self.name + self.gradients_root + stat_name
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=self.iteration)

    # override
    def write(self):
        try:
            if len(self.gradients) == 0: return
            norms = [x.data.norm() for x in self.gradients]
            self._add_gradient_scalar('avg_norm', sum(norms)/len(norms))
            self._add_gradient_scalar('median_norm', statistics.median(norms))
            self._add_gradient_scalar('max_norm', max(norms))
            self._add_gradient_scalar('min_norm', min(norms))
            gradient_nps = [to_np(x.data) for x in self.gradients]
            self._add_gradient_scalar('num_zeros', sum((np.asarray(x) == 0.0).sum() for x in gradient_nps))
            self._add_gradient_scalar('avg_gradient', sum(x.data.mean() for x in self.gradients)/len(self.gradients))
            self._add_gradient_scalar('median_gradient', statistics.median(x.data.median() for x in self.gradients))
            self._add_gradient_scalar('max_gradient', max(x.data.max() for x in self.gradients))
            self._add_gradient_scalar('min_gradient', min(x.data.min() for x in self.gradients))
        except Exception as e:
            print("Failed to write model stats to Tensorboard: {0}".format(e))

class ModelStatsTBWriter():
    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
        request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)
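
# With the default name of 'model_stats', the request above produces scalars under
# 'model_stats/gradients/': avg_norm, median_norm, max_norm, min_norm, num_zeros,
# avg_gradient, median_gradient, max_gradient and min_gradient, each computed over
# every parameter that currently has a gradient.
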
class ImageTBRequest(TBWriteRequest):
    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
        self.ds_type = ds_type

    # override
    def write(self):
        try:
            orig_images, gen_images, real_images = [], [], []
            for image_set in self.image_sets:
                orig_images.append(image_set.orig.px)
                gen_images.append(image_set.gen.px)
                real_images.append(image_set.real.px)
            prefix = self.ds_type.name
            self.tbwriter.add_image(
                tag=prefix + ' orig images', img_tensor=vutils.make_grid(orig_images, normalize=True),
                global_step=self.iteration)
            self.tbwriter.add_image(
                tag=prefix + ' gen images', img_tensor=vutils.make_grid(gen_images, normalize=True),
                global_step=self.iteration)
            self.tbwriter.add_image(
                tag=prefix + ' real images', img_tensor=vutils.make_grid(real_images, normalize=True),
                global_step=self.iteration)
        except Exception as e:
            print("Failed to write images to Tensorboard: {0}".format(e))

# If this isn't done asynchronously, it's noticeably slower.
class ImageTBWriter():
    def __init__(self):
        super().__init__()

    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
        self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration,
                               tbwriter=tbwriter, ds_type=DatasetType.Valid)
        self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration,
                               tbwriter=tbwriter, ds_type=DatasetType.Train)

    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
        asyncTBWriter.request_write(request)