visualize.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from numpy import ndarray
  2. from fastai.torch_imports import *
  3. from fastai.core import *
  4. from matplotlib.axes import Axes
  5. from matplotlib.figure import Figure
  6. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  7. from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
  8. from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
  9. from fastai.transforms import CropType, NoCrop, Denormalize, Scale
  10. from .transforms import BlackAndWhiteTransform
  11. from .training import GenResult, CriticResult, GANTrainer
  12. from .images import ModelImageSet, EasyTensorImage
  13. from .generators import GeneratorModule
  14. from .filters import Filter
  15. from IPython.display import display
  16. from tensorboardX import SummaryWriter
  17. from scipy import misc
  18. import torchvision.utils as vutils
  19. import statistics
  20. from PIL import Image
  21. class ModelImageVisualizer():
  22. def __init__(self, filters:[Filter]=[], render_factor:int=18, results_dir:str=None):
  23. self.filters = filters
  24. self.render_factor=render_factor
  25. self.results_dir=None if results_dir is None else Path(results_dir)
  26. def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->ndarray:
  27. path = Path(path)
  28. result = self._get_transformed_image_ndarray(path, render_factor)
  29. orig = open_image(str(path))
  30. fig,axes = plt.subplots(1, 2, figsize=figsize)
  31. self._plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
  32. self._plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
  33. if self.results_dir is not None:
  34. self._save_result_image(path, result)
  35. def get_transformed_image_as_pil(self, path:str, render_factor:int=None)->Image:
  36. path = Path(path)
  37. array = self._get_transformed_image_ndarray(path, render_factor)
  38. return misc.toimage(array)
  39. def _save_result_image(self, source_path:Path, result:ndarray):
  40. result_path = self.results_dir/source_path.name
  41. misc.imsave(result_path, np.clip(result,0,1))
  42. def _get_transformed_image_ndarray(self, path:Path, render_factor:int=None):
  43. orig_image = open_image(str(path))
  44. filtered_image = orig_image
  45. render_factor = self.render_factor if render_factor is None else render_factor
  46. for filt in self.filters:
  47. filtered_image = filt.filter(orig_image, filtered_image, render_factor=render_factor)
  48. return filtered_image
  49. def _plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
  50. if axes is None:
  51. _,axes = plt.subplots(figsize=figsize)
  52. clipped_image =np.clip(image,0,1)
  53. axes.imshow(clipped_image)
  54. axes.axis('off')
  55. def _get_num_rows_columns(self, num_images:int, max_columns:int):
  56. columns = min(num_images, max_columns)
  57. rows = num_images//columns
  58. rows = rows if rows * columns == num_images else rows + 1
  59. return rows, columns
  60. class ModelGraphVisualizer():
  61. def __init__(self):
  62. return
  63. def write_model_graph_to_tensorboard(self, ds:FilesDataset, model:nn.Module, tbwriter:SummaryWriter):
  64. try:
  65. x,_=ds[0]
  66. tbwriter.add_graph(model, V(x[None]))
  67. except Exception as e:
  68. print(("Failed to generate graph for model: {0}. Note that there's an outstanding issue with "
  69. + "scopes being addressed here: https://github.com/pytorch/pytorch/pull/12400").format(e))
  70. class ModelHistogramVisualizer():
  71. def __init__(self):
  72. return
  73. def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
  74. for name, param in model.named_parameters():
  75. tbwriter.add_histogram('/weights/' + name, param, iter_count)
  76. class ModelStatsVisualizer():
  77. def __init__(self):
  78. return
  79. def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
  80. gradients = [x.grad for x in model.parameters() if x.grad is not None]
  81. gradient_nps = [to_np(x.data) for x in gradients]
  82. if len(gradients) == 0:
  83. return
  84. avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
  85. tbwriter.add_scalar('/gradients/avg_norm', avg_norm, iter_count)
  86. median_norm = statistics.median(x.data.norm() for x in gradients)
  87. tbwriter.add_scalar('/gradients/median_norm', median_norm, iter_count)
  88. max_norm = max(x.data.norm() for x in gradients)
  89. tbwriter.add_scalar('/gradients/max_norm', max_norm, iter_count)
  90. min_norm = min(x.data.norm() for x in gradients)
  91. tbwriter.add_scalar('/gradients/min_norm', min_norm, iter_count)
  92. num_zeros = sum((np.asarray(x)==0.0).sum() for x in gradient_nps)
  93. tbwriter.add_scalar('/gradients/num_zeros', num_zeros, iter_count)
  94. avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
  95. tbwriter.add_scalar('/gradients/avg_gradient', avg_gradient, iter_count)
  96. median_gradient = statistics.median(x.data.median() for x in gradients)
  97. tbwriter.add_scalar('/gradients/median_gradient', median_gradient, iter_count)
  98. max_gradient = max(x.data.max() for x in gradients)
  99. tbwriter.add_scalar('/gradients/max_gradient', max_gradient, iter_count)
  100. min_gradient = min(x.data.min() for x in gradients)
  101. tbwriter.add_scalar('/gradients/min_gradient', min_gradient, iter_count)
  102. class ImageGenVisualizer():
  103. def __init__(self):
  104. self.model_vis = ModelImageVisualizer()
  105. def output_image_gen_visuals(self, md:ImageData, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter):
  106. self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=True)
  107. self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, validation=False)
  108. def _output_visuals(self, ds:FilesDataset, model:GeneratorModule, iter_count:int, tbwriter:SummaryWriter, validation:bool):
  109. #TODO: Parameterize these
  110. start_idx=0
  111. count = 8
  112. end_index = start_idx + count
  113. idxs = list(range(start_idx,end_index))
  114. image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
  115. self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, validation=validation)
  116. def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, validation:bool):
  117. orig_images = []
  118. gen_images = []
  119. real_images = []
  120. for image_set in image_sets:
  121. orig_images.append(image_set.orig.tensor)
  122. gen_images.append(image_set.gen.tensor)
  123. real_images.append(image_set.real.tensor)
  124. prefix = 'val' if validation else 'train'
  125. tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iter_count)
  126. tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iter_count)
  127. tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
  128. class GANTrainerStatsVisualizer():
  129. def __init__(self):
  130. return
  131. def write_tensorboard_stats(self, gresult:GenResult, cresult:CriticResult, iter_count:int, tbwriter:SummaryWriter):
  132. tbwriter.add_scalar('/loss/hingeloss', cresult.hingeloss, iter_count)
  133. tbwriter.add_scalar('/loss/dfake', cresult.dfake, iter_count)
  134. tbwriter.add_scalar('/loss/dreal', cresult.dreal, iter_count)
  135. tbwriter.add_scalar('/loss/gcost', gresult.gcost, iter_count)
  136. tbwriter.add_scalar('/loss/gcount', gresult.iters, iter_count)
  137. tbwriter.add_scalar('/loss/gaddlloss', gresult.gaddlloss, iter_count)
  138. def print_stats_in_jupyter(self, gresult:GenResult, cresult:CriticResult):
  139. print(f'\nHingeLoss {cresult.hingeloss}; RScore {cresult.dreal}; FScore {cresult.dfake}; GAddlLoss {gresult.gaddlloss}; ' +
  140. f'Iters: {gresult.iters}; GCost: {gresult.gcost};')
  141. class LearnerStatsVisualizer():
  142. def __init__(self):
  143. return
  144. def write_tensorboard_stats(self, metrics, iter_count:int, tbwriter:SummaryWriter):
  145. if isinstance(metrics, list):
  146. tbwriter.add_scalar('/loss/trn_loss', metrics[0], iter_count)
  147. if len(metrics) == 1: return
  148. tbwriter.add_scalar('/loss/val_loss', metrics[1], iter_count)
  149. if len(metrics) == 2: return
  150. for metric in metrics[2:]:
  151. name = metric.__name__
  152. tbwriter.add_scalar('/loss/'+name, metric, iter_count)
  153. else:
  154. tbwriter.add_scalar('/loss/trn_loss', metrics, iter_count)