visualize.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. from numpy import ndarray
  2. from fastai.torch_imports import *
  3. from fastai.core import *
  4. from matplotlib.axes import Axes
  5. from fastai.dataset import FilesDataset, ImageData, ModelData, open_image
  6. from fastai.transforms import Transform, scale_min, tfms_from_stats, inception_stats
  7. from fastai.transforms import CropType, NoCrop, Denormalize
  8. from fasterai.training import GenResult, CriticResult, GANTrainer
  9. from fasterai.images import ModelImageSet, EasyTensorImage
  10. from IPython.display import display
  11. from tensorboardX import SummaryWriter
  12. import torchvision.utils as vutils
  13. import statistics
  class ModelImageVisualizer():
      """Plot images and model outputs with matplotlib.

      Images are read with `open_image`, resized with `scale_min`, normalized
      through the validation transforms built from `inception_stats`, passed
      through a model, and denormalized again before being drawn.
      """
      def __init__(self, default_sz:int=500):
          """Store the fallback image size and build the denormalizer.

          default_sz: size used by `get_model_ready_image_ndarray` whenever the
              caller does not pass `sz`.
          """
          self.default_sz=default_sz
          self.denorm = Denormalize(*inception_stats)

      def plot_transformed_image(self, path:Path, model:nn.Module, figsize:(int,int)=(20,20), sz:int=None,
              tfms:[Transform]=[], compare:bool=True):
          """Run `model` on the image at `path` and plot the result.

          When `compare` is true the untouched image (as returned by
          `open_image`) is drawn on the left and the model output on the
          right; otherwise only the model output is drawn.
          """
          result = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
          if compare:
              orig = open_image(str(path))
              # The figure handle is not needed; only the two axes are used.
              _,axes = plt.subplots(1, 2, figsize=figsize)
              self.plot_image_from_ndarray(orig, axes=axes[0], figsize=figsize)
              self.plot_image_from_ndarray(result, axes=axes[1], figsize=figsize)
          else:
              self.plot_image_from_ndarray(result, figsize=figsize)

      def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
          """Return the denormalized output of `model` for the image at `path`.

          The model is switched to eval mode for the forward pass. If it was in
          training mode on entry it is switched back afterwards, including when
          the forward pass raises.
          """
          training = model.training
          model.eval()
          try:
              orig = self.get_model_ready_image_ndarray(path, model, sz, tfms)
              # `orig[None]` adds a leading axis so the model receives a batch of one.
              orig = VV(orig[None])
              result = model(orig).detach().cpu().numpy()
              result = self._denorm(result)
          finally:
              # Restore the caller's mode even if inference failed.
              if training:
                  model.train()
          # Drop the leading batch axis again.
          return result[0]

      def _denorm(self, image: ndarray):
          """Move axis 1 to the end and apply `self.denorm`.

          A 3-D input is first promoted to 4-D by adding a leading axis, so
          both a single image and a batch are accepted.
          """
          # Fix: the original assigned to an undefined name (`arr`), which
          # raised for every 3-D input.
          if len(image.shape)==3: image = image[None]
          # rollaxis(image,1,4) reorders a 4-D array's axes to (0,2,3,1);
          # presumably channels-first -> channels-last (TODO confirm against Denormalize).
          return self.denorm(np.rollaxis(image,1,4))

      def _transform(self, orig:ndarray, tfms:[Transform], model:nn.Module, sz:int):
          """Apply `tfms`, then the validation transforms for `inception_stats`.

          `model` is accepted but not used by this method.
          """
          for tfm in tfms:
              # Each transform returns a pair; only the first element is kept.
              orig,_=tfm(orig, False)
          _,val_tfms = tfms_from_stats(inception_stats, sz, crop_type=CropType.NO, aug_tfms=[])
          # Strip any NoCrop step from the validation pipeline before applying it.
          val_tfms.tfms = [tfm for tfm in val_tfms.tfms if not isinstance(tfm, NoCrop)]
          orig = val_tfms(orig)
          return orig

      def get_model_ready_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
          """Open the image at `path`, resize it with `scale_min` and transform it.

          `sz` falls back to `self.default_sz` when it is None.
          """
          im = open_image(str(path))
          sz = self.default_sz if sz is None else sz
          im = scale_min(im, sz)
          im = self._transform(im, tfms, model, sz)
          return im

      def plot_image_from_ndarray(self, image:ndarray, axes:Axes=None, figsize=(20,20)):
          """Draw `image` on `axes`, clipping values to [0, 1] and hiding the axis.

          A new figure of size `figsize` is created when `axes` is None.
          """
          if axes is None:
              _,axes = plt.subplots(figsize=figsize)
          clipped_image =np.clip(image,0,1)
          axes.imshow(clipped_image)
          axes.axis('off')

      def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
              max_columns:int=6, immediate_display:bool=True):
          """Plot the `orig` and `gen` image of every set side by side in a grid.

          The figure is titled 'Validation' or 'Training' depending on
          `validation`, and is passed to `display` when `immediate_display` is
          true. Nothing is drawn for an empty `image_sets`.
          """
          num_sets = len(image_sets)
          if num_sets == 0:
              # An empty grid would otherwise divide by zero when sizing the layout.
              return
          # Two cells per set: the original and the generated image.
          num_images = num_sets * 2
          rows, columns = self._get_num_rows_columns(num_images, max_columns)
          fig, axes = plt.subplots(rows, columns, figsize=figsize)
          title = 'Validation' if validation else 'Training'
          fig.suptitle(title, fontsize=16)

          for i, image_set in enumerate(image_sets):
              self.plot_image_from_ndarray(image_set.orig.array, axes=axes.flat[i*2])
              self.plot_image_from_ndarray(image_set.gen.array, axes=axes.flat[i*2+1])

          if immediate_display:
              display(fig)

      def plot_image_outputs_from_model(self, ds:FilesDataset, model:nn.Module, idxs:[int], figsize:(int,int)=(20,20), max_columns:int=6,
              immediate_display:bool=True, validation:bool=False):
          """Build image sets for `idxs` of `ds` with `model` and plot them.

          validation: selects the figure title ('Validation' when true,
              'Training' otherwise). Added with a default because the original
              call omitted this required argument and always raised TypeError.
          """
          image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
          self.plot_images_from_image_sets(image_sets=image_sets, validation=validation, figsize=figsize,
              max_columns=max_columns, immediate_display=immediate_display)

      def _get_num_rows_columns(self, num_images:int, max_columns:int):
          """Return `(rows, columns)` for a grid holding `num_images` cells.

          `columns` is capped at `max_columns`; `rows` is rounded up so that a
          partially filled last row is still counted.
          """
          columns = min(num_images, max_columns)
          rows = num_images//columns
          rows = rows if rows * columns == num_images else rows + 1
          return rows, columns
  class ModelGraphVisualizer():
      """Writes a model's graph to a tensorboard writer."""
      def __init__(self):
          """Nothing to initialise; the visualizer keeps no state."""

      def write_model_graph_to_tensorboard(self, ds:FilesDataset, model:nn.Module, tbwriter:SummaryWriter):
          """Trace `model` on the first item of `ds` and hand the graph to `tbwriter`.

          This is best-effort: any exception raised while fetching the sample
          or adding the graph is caught and reported with `print`, never
          propagated to the caller.
          """
          failure_template = ("Failed to generate graph for model: {0}. Note that there's an outstanding issue with "
              "scopes being addressed here: https://github.com/pytorch/pytorch/pull/12400")
          try:
              # The first dataset item is a pair; only its first element is fed to the model.
              sample, _ = ds[0]
              # `sample[None]` adds a leading axis so the model sees a batch of one.
              batch = V(sample[None])
              tbwriter.add_graph(model, batch)
          except Exception as err:
              print(failure_template.format(err))
  class ModelHistogramVisualizer():
      """Writes one histogram per named model parameter to a tensorboard writer."""
      def __init__(self):
          """Nothing to initialise; the visualizer keeps no state."""

      def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
          """Add a histogram for every entry of `model.named_parameters()`.

          Each histogram is tagged '/weights/<parameter name>' and recorded at
          step `iter_count`.
          """
          for param_name, weights in model.named_parameters():
              tag = '/weights/' + param_name
              tbwriter.add_histogram(tag, weights, iter_count)
  class ModelStatsVisualizer():
      """Writes summary statistics of a model's gradients to a tensorboard writer."""
      def __init__(self):
          """Nothing to initialise; the visualizer keeps no state."""

      def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
          """Record gradient statistics for `model` at step `iter_count`.

          Only parameters whose `.grad` is not None are considered. When there
          are none the method returns without writing anything. Otherwise it
          writes, under '/gradients/', the average/median/max/min of the
          per-parameter gradient norms, the total count of exactly-zero
          gradient entries, and the average/median/max/min of the
          per-parameter gradient mean/median/max/min values.
          """
          gradients = [x.grad for x in model.parameters() if x.grad is not None]
          if not gradients:
              return

          # Computed only after the guard so no conversion work is done when
          # there is nothing to report.
          gradient_nps = [to_np(x.data) for x in gradients]
          # Each norm is computed once and reused by all four norm statistics.
          norms = [x.data.norm() for x in gradients]

          avg_norm = sum(norms)/len(gradients)
          tbwriter.add_scalar('/gradients/avg_norm', avg_norm, iter_count)
          median_norm = statistics.median(norms)
          tbwriter.add_scalar('/gradients/median_norm', median_norm, iter_count)
          max_norm = max(norms)
          tbwriter.add_scalar('/gradients/max_norm', max_norm, iter_count)
          min_norm = min(norms)
          tbwriter.add_scalar('/gradients/min_norm', min_norm, iter_count)

          num_zeros = sum((np.asarray(x)==0.0).sum() for x in gradient_nps)
          tbwriter.add_scalar('/gradients/num_zeros', num_zeros, iter_count)

          avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
          tbwriter.add_scalar('/gradients/avg_gradient', avg_gradient, iter_count)
          median_gradient = statistics.median(x.data.median() for x in gradients)
          tbwriter.add_scalar('/gradients/median_gradient', median_gradient, iter_count)
          max_gradient = max(x.data.max() for x in gradients)
          tbwriter.add_scalar('/gradients/max_gradient', max_gradient, iter_count)
          min_gradient = min(x.data.min() for x in gradients)
          tbwriter.add_scalar('/gradients/min_gradient', min_gradient, iter_count)
  class ImageGenVisualizer():
      """Sends sample original/generated/real images to tensorboard and, optionally, Jupyter."""
      def __init__(self):
          """Create the `ModelImageVisualizer` used for the Jupyter plots."""
          self.model_vis = ModelImageVisualizer()

      def output_image_gen_visuals(self, md:ImageData, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, jupyter:bool=False,
              start_idx:int=0, count:int=8):
          """Output visuals for the validation set and then the training set of `md`.

          start_idx, count: first dataset index and number of consecutive items
              to visualise from each dataset (defaults match the previously
              hard-coded 0 and 8).
          """
          self._output_visuals(ds=md.val_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=True,
              start_idx=start_idx, count=count)
          self._output_visuals(ds=md.trn_ds, model=model, iter_count=iter_count, tbwriter=tbwriter, jupyter=jupyter, validation=False,
              start_idx=start_idx, count=count)

      def _output_visuals(self, ds:FilesDataset, model:nn.Module, iter_count:int, tbwriter:SummaryWriter,
              validation:bool, jupyter:bool=False, start_idx:int=0, count:int=8):
          """Generate image sets for `count` items of `ds` starting at `start_idx` and output them.

          The images always go to `tbwriter`; they are also plotted inline when
          `jupyter` is true.
          """
          idxs = list(range(start_idx, start_idx + count))
          image_sets = ModelImageSet.get_list_from_model(ds=ds, model=model, idxs=idxs)
          self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, validation=validation)
          if jupyter:
              self._show_images_in_jupyter(image_sets, validation=validation)

      def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, validation:bool):
          """Write three image grids (orig, gen, real) to `tbwriter` at step `iter_count`.

          Tags are prefixed with 'val' or 'train' depending on `validation`.
          """
          orig_images = [image_set.orig.tensor for image_set in image_sets]
          gen_images = [image_set.gen.tensor for image_set in image_sets]
          real_images = [image_set.real.tensor for image_set in image_sets]

          prefix = 'val' if validation else 'train'
          tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iter_count)
          tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iter_count)
          tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)

      def _show_images_in_jupyter(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
              max_columns:int=4, immediate_display:bool=True):
          """Plot `image_sets` through `self.model_vis`.

          figsize, max_columns, immediate_display: forwarded unchanged to
              `plot_images_from_image_sets`; the defaults match the values
              that used to be hard-coded here.
          """
          self.model_vis.plot_images_from_image_sets(image_sets, figsize=figsize, max_columns=max_columns,
              immediate_display=immediate_display, validation=validation)
  class GANTrainerStatsVisualizer():
      """Reports generator and critic results to tensorboard or standard output."""
      def __init__(self):
          """Nothing to initialise; the visualizer keeps no state."""

      def write_tensorboard_stats(self, gresult:GenResult, cresult:CriticResult, iter_count:int, tbwriter:SummaryWriter):
          """Write six scalars under '/loss/' at step `iter_count`.

          Three come from `cresult` (hingeloss, dfake, dreal) and three from
          `gresult` (gcost, iters, gaddlloss). Note that `gresult.iters` is
          written under the tag '/loss/gcount'.
          """
          # (tag, source object, attribute) in the order the scalars are written.
          scalar_specs = (
              ('/loss/hingeloss', cresult, 'hingeloss'),
              ('/loss/dfake', cresult, 'dfake'),
              ('/loss/dreal', cresult, 'dreal'),
              ('/loss/gcost', gresult, 'gcost'),
              ('/loss/gcount', gresult, 'iters'),
              ('/loss/gaddlloss', gresult, 'gaddlloss'),
          )
          for tag, source, attr in scalar_specs:
              # Read each attribute just before writing it, one scalar at a time.
              tbwriter.add_scalar(tag, getattr(source, attr), iter_count)

      def print_stats_in_jupyter(self, gresult:GenResult, cresult:CriticResult):
          """Print a one-line summary of the critic and generator results."""
          summary = (f'\nHingeLoss {cresult.hingeloss}; RScore {cresult.dreal}; FScore {cresult.dfake}; GAddlLoss {gresult.gaddlloss}; '
              f'Iters: {gresult.iters}; GCost: {gresult.gcost};')
          print(summary)
  class LearnerStatsVisualizer():
      """Writes learner loss/metric values to a tensorboard writer."""
      def __init__(self):
          """Nothing to initialise; the visualizer keeps no state."""

      def write_tensorboard_stats(self, metrics, iter_count:int, tbwriter:SummaryWriter):
          """Record `metrics` under '/loss/' at step `iter_count`.

          metrics: either a single value, written as '/loss/trn_loss', or a
              list whose first entry is written as '/loss/trn_loss', whose
              second entry is written as '/loss/val_loss', and whose remaining
              entries are written under '/loss/<name>'.

          For entries past the second, `<name>` is the entry's `__name__` when
          it has one; otherwise the positional tag 'metric_<i>' (counting from
          0 at the third entry) is used. Plain numbers carry no `__name__`, so
          the original code raised AttributeError for them. An empty list
          writes nothing.
          """
          if not isinstance(metrics, list):
              tbwriter.add_scalar('/loss/trn_loss', metrics, iter_count)
              return
          if not metrics:
              # Nothing to log; avoids IndexError on metrics[0].
              return

          tbwriter.add_scalar('/loss/trn_loss', metrics[0], iter_count)
          if len(metrics) == 1: return
          tbwriter.add_scalar('/loss/val_loss', metrics[1], iter_count)

          for i, metric in enumerate(metrics[2:]):
              # NOTE(review): the real metric names are not available here;
              # confirm against the caller whether they can be passed in.
              name = getattr(metric, '__name__', f'metric_{i}')
              tbwriter.add_scalar('/loss/'+name, metric, iter_count)
  185. else:
  186. tbwriter.add_scalar('/loss/trn_loss', metrics, iter_count)