import fastai
from fastai.basic_train import Learner
from fastai.basic_data import DatasetType, DataBunch
from fastai.vision import Image
from fastai.callbacks import LearnerCallback
from fastai.core import *
from fastai.torch_core import *
from threading import Thread, Event
from time import sleep
from queue import Queue
import statistics
import torchvision.utils as vutils
from abc import ABC, abstractmethod
from tensorboardX import SummaryWriter

class TBWriteRequest(ABC):
    "A request object for a Tensorboard write.  Lets writes be queued and processed asynchronously."
    def __init__(self, tbwriter: SummaryWriter, iteration:int):
        super().__init__()
        self.tbwriter = tbwriter
        self.iteration = iteration

    @abstractmethod
    def write(self): pass

# SummaryWriter writes tend to block quite a bit, which can noticeably slow training; queueing them on a
# background thread gets around that and greatly boosts performance. Not all Tensorboard writes go through
# this, just the ones that take a long time. Note that SummaryWriter itself ultimately uses a thread-safe
# producer/consumer design to write to Tensorboard, so writes done outside of this async loop are still safe.
class AsyncTBWriter():
    "Processes TBWriteRequests on a daemon thread so that slow writes don't block training."
    def __init__(self):
        super().__init__()
        self.stoprequest = Event()
        self.queue = Queue()
        self.thread = Thread(target=self._queue_processor, daemon=True)
        self.thread.start()

    def request_write(self, request: TBWriteRequest):
        if self.stoprequest.is_set():
            raise Exception('Close was already called!  Cannot perform this operation.')
        self.queue.put(request)

    def _queue_processor(self):
        # Drain any pending requests, then poll again until a stop is requested.
        while not self.stoprequest.is_set():
            while not self.queue.empty():
                request = self.queue.get()
                request.write()
            sleep(0.2)

    # Provided to stop the thread explicitly or via context management (with statement), but the thread
    # should end on its own upon program exit since it's a daemon. So using this is probably unnecessary.
    def close(self):
        self.stoprequest.set()
        self.thread.join()

    def __enter__(self):
        # The thread is already started, so there's nothing to do here beyond returning self (required for
        # "with ... as" to bind correctly). The thread could be started here instead, to enforce use of the
        # context manager, but that seems unwieldy and unnecessary for actual usage.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

asyncTBWriter = AsyncTBWriter()
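
# --- Illustrative sketch, not part of the original module: any slow write can be routed through
# asyncTBWriter by subclassing TBWriteRequest, exactly as the built-in writers below do. The class here
# is a hypothetical example; the tag and value are arbitrary.
class _ExampleScalarTBRequest(TBWriteRequest):
    "Hypothetical request that writes a single scalar off the training thread."
    def __init__(self, tbwriter:SummaryWriter, iteration:int, tag:str, value:float):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.tag, self.value = tag, value

    # override
    def write(self):
        # Executes on the daemon thread, keeping the training loop unblocked.
        self.tbwriter.add_scalar(tag=self.tag, scalar_value=self.value, global_step=self.iteration)

# Usage: asyncTBWriter.request_write(_ExampleScalarTBRequest(tbwriter, iteration=10, tag='demo', value=0.5))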

class ModelImageSet():
    "Groups the original, real (target) and generated images for a single example."
    @staticmethod
    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->list:
        image_sets = []
        x,y = batch[0],batch[1]
        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
        for orig_px, real_px, gen in zip(x,y,preds):
            orig = Image(px=orig_px)
            real = Image(px=real_px)
            image_set = ModelImageSet(orig=orig, real=real, gen=gen)
            image_sets.append(image_set)
        return image_sets

    def __init__(self, orig:Image, real:Image, gen:Image):
        self.orig = orig
        self.real = real
        self.gen = gen

class HistogramTBRequest(TBWriteRequest):
    "Request object for model parameter histogram writes to Tensorboard."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # Clone to cpu on the training thread so the write thread never touches live parameters.
        self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]
        self.name = name

    # override
    def write(self):
        try:
            for param_name, values in self.params:
                tag = self.name + '/weights/' + param_name
                self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)
        except Exception as e:
            print(f"Failed to write model histograms to Tensorboard: {e}")

# If this isn't done asynchronously, it's very slow.
class HistogramTBWriter():
    "Writes model parameter histograms to Tensorboard."
    def __init__(self):
        super().__init__()

    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
        request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)

class ModelStatsTBRequest(TBWriteRequest):
    "Request object for model gradient statistics writes to Tensorboard."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]
        self.name = name
        self.gradients_root = '/gradients/'

    def _add_gradient_scalar(self, name:str, scalar_value):
        tag = self.name + self.gradients_root + name
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=self.iteration)

    # override
    def write(self):
        try:
            if len(self.gradients) == 0: return
            norms = [x.data.norm() for x in self.gradients]
            self._add_gradient_scalar('avg_norm', sum(norms)/len(norms))
            self._add_gradient_scalar('median_norm', statistics.median(norms))
            self._add_gradient_scalar('max_norm', max(norms))
            self._add_gradient_scalar('min_norm', min(norms))
            gradient_nps = [to_np(x.data) for x in self.gradients]
            self._add_gradient_scalar('num_zeros', sum((x == 0.0).sum() for x in gradient_nps))
            self._add_gradient_scalar('avg_gradient', sum(x.data.mean() for x in self.gradients)/len(self.gradients))
            self._add_gradient_scalar('median_gradient', statistics.median(x.data.median() for x in self.gradients))
            self._add_gradient_scalar('max_gradient', max(x.data.max() for x in self.gradients))
            self._add_gradient_scalar('min_gradient', min(x.data.min() for x in self.gradients))
        except Exception as e:
            print(f"Failed to write model stats to Tensorboard: {e}")

class ModelStatsTBWriter():
    "Writes model gradient statistics to Tensorboard."
    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
        request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)
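
# --- Illustrative sketch, not part of the original module: the histogram and stats writers can also be
# driven by hand, outside the callbacks further down. The log directory here is hypothetical. Note that
# ModelStatsTBRequest reads .grad, so it only produces output after a backward pass has populated gradients.
def _example_manual_write(model:nn.Module, iteration:int):
    tbwriter = SummaryWriter(log_dir='runs/manual_example')
    HistogramTBWriter().write(model=model, iteration=iteration, tbwriter=tbwriter)
    ModelStatsTBWriter().write(model=model, iteration=iteration, tbwriter=tbwriter)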

class ImageTBRequest(TBWriteRequest):
    "Request object for model image output writes to Tensorboard."
    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
        self.ds_type = ds_type

    # override
    def write(self):
        try:
            orig_images, gen_images, real_images = [], [], []
            for image_set in self.image_sets:
                orig_images.append(image_set.orig.px)
                gen_images.append(image_set.gen.px)
                real_images.append(image_set.real.px)
            prefix = self.ds_type.name
            self.tbwriter.add_image(
                tag=prefix + ' orig images', img_tensor=vutils.make_grid(orig_images, normalize=True),
                global_step=self.iteration)
            self.tbwriter.add_image(
                tag=prefix + ' gen images', img_tensor=vutils.make_grid(gen_images, normalize=True),
                global_step=self.iteration)
            self.tbwriter.add_image(
                tag=prefix + ' real images', img_tensor=vutils.make_grid(real_images, normalize=True),
                global_step=self.iteration)
        except Exception as e:
            print(f"Failed to write images to Tensorboard: {e}")

# If this isn't done asynchronously, it's noticeably slower.
class ImageTBWriter():
    "Writes model image output grids to Tensorboard."
    def __init__(self):
        super().__init__()

    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
        self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration,
                               tbwriter=tbwriter, ds_type=DatasetType.Valid)
        self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration,
                               tbwriter=tbwriter, ds_type=DatasetType.Train)

    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
        asyncTBWriter.request_write(request)

#--------CALLBACKS----------------#
class LearnerTensorboardWriter(LearnerCallback):
    "Broadly useful callback that writes a Learner's losses, metrics, histograms and gradient stats to Tensorboard."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
        super().__init__(learn=learn)
        self.base_dir = base_dir
        self.name = name
        log_dir = base_dir/name
        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
        self.loss_iters = loss_iters
        self.hist_iters = hist_iters
        self.stats_iters = stats_iters
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        self.data = None
        self.metrics_root = '/metrics/'
        self._update_batches_if_needed()

    def _update_batches_if_needed(self):
        # The one_batch function is extremely slow with large datasets, so this caching is an optimization.
        # It also means we always show the same batches, which makes changes easy to see in Tensorboard.
        update_batches = self.data is not self.learn.data
        if update_batches:
            self.data = self.learn.data
            self.trn_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
            self.val_batch = self.learn.data.one_batch(
                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)

    def _write_model_stats(self, iteration:int):
        self.stats_writer.write(
            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        scalar_value = to_np(last_loss)
        tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_weight_histograms(self, iteration:int):
        self.hist_writer.write(
            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    # TODO: Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
        recorder = self.learn.recorder
        for i, name in enumerate(recorder.names[start_idx:]):
            if len(last_metrics) < i+1: return
            scalar_value = last_metrics[i]
            tag = self.metrics_root + name
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.loss_iters == 0:
            self._write_training_loss(iteration=iteration, last_loss=last_loss)
        if iteration % self.hist_iters == 0:
            self._write_weight_histograms(iteration=iteration)

    # Work requiring gradient info happens here, because gradients get zeroed out afterwards in the training loop.
    def on_backward_end(self, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.stats_iters == 0:
            self._write_model_stats(iteration=iteration)

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
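
# --- Illustrative sketch, not part of the original module: attaching the callback to an existing
# fastai v1 Learner. The Learner itself and the 'runs' base directory are hypothetical.
def _example_learner_usage(learn:Learner):
    learn.callbacks.append(LearnerTensorboardWriter(
        learn=learn, base_dir=Path('runs'), name='baseline',
        loss_iters=25, hist_iters=500, stats_iters=100))
    learn.fit(1)  # losses, histograms and gradient stats now stream to runs/baseline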

# TODO: We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a").
class GANTensorboardWriter(LearnerTensorboardWriter):
    "Tensorboard callback for GAN learners; tracks the generator and critic separately and writes image grids."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                 stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
                         hist_iters=hist_iters, stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()
        self.gen_stats_updated = True
        self.crit_stats_updated = True

    # override
    def _write_weight_histograms(self, iteration:int):
        # Write generator and critic histograms separately so they get distinct tags.
        trainer = self.learn.gan_trainer
        generator, critic = trainer.generator, trainer.critic
        self.hist_writer.write(
            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
        self.hist_writer.write(
            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')

    # override
    def _write_model_stats(self, iteration:int):
        trainer = self.learn.gan_trainer
        generator, critic = trainer.generator, trainer.critic
        # We don't want to write stats for a model that isn't being iterated on, since its gradients are zeroed out.
        gen_mode = trainer.gen_mode
        if gen_mode and not self.gen_stats_updated:
            self.stats_writer.write(
                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
            self.gen_stats_updated = True
        if not gen_mode and not self.crit_stats_updated:
            self.stats_writer.write(
                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
            self.crit_stats_updated = True

    # override
    def _write_training_loss(self, iteration:int, last_loss:Tensor):
        trainer = self.learn.gan_trainer
        recorder = trainer.recorder
        if len(recorder.losses) > 0:
            scalar_value = to_np(recorder.losses[-1])
            tag = self.metrics_root + 'train_loss'
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write(self, iteration:int):
        trainer = self.learn.gan_trainer
        # TODO: Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side effect.  Is there a better way?
        gen_mode = trainer.gen_mode
        try:
            trainer.switch(gen_mode=True)
            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                   iteration=iteration, tbwriter=self.tbwriter)
        finally:
            trainer.switch(gen_mode=gen_mode)

    # override
    def on_batch_end(self, iteration:int, **kwargs):
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write(iteration=iteration)

    # override
    def on_backward_end(self, iteration:int, **kwargs):
        if iteration == 0: return
        self._update_batches_if_needed()
        # TODO: This could perhaps be implemented as queues of requests instead, but that seemed like overkill.
        # Maintaining these boolean flags isn't great either, though; worth revisiting.
        if iteration % self.stats_iters == 0:
            self.gen_stats_updated = False
            self.crit_stats_updated = False
        if not (self.gen_stats_updated and self.crit_stats_updated):
            self._write_model_stats(iteration=iteration)
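
# --- Illustrative sketch, not part of the original module: a GAN learner built elsewhere (e.g. with
# fastai.vision.gan.GANLearner) attaches this writer the same way; visual_iters controls how often
# image grids are written. The learner and 'runs' directory are hypothetical.
def _example_gan_usage(gan_learn:Learner):
    gan_learn.callbacks.append(GANTensorboardWriter(
        learn=gan_learn, base_dir=Path('runs'), name='gan', visual_iters=100))
    gan_learn.fit(1)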

class ImageGenTensorboardWriter(LearnerTensorboardWriter):
    "Tensorboard callback for image-generating learners; adds periodic image grid writes."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                 stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
                         stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()

    def _write(self, iteration:int):
        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                               iteration=iteration, tbwriter=self.tbwriter)

    # override
    def on_batch_end(self, iteration:int, **kwargs):
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write(iteration=iteration)
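
# --- Illustrative sketch, not part of the original module: the image-generation variant is attached
# identically. Once training runs, results can be viewed with `tensorboard --logdir runs`; the
# directory name is hypothetical and must match base_dir above.
def _example_image_gen_usage(learn:Learner):
    learn.callbacks.append(ImageGenTensorboardWriter(
        learn=learn, base_dir=Path('runs'), name='image_gen', visual_iters=100))
    learn.fit(1)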