plots.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from .imports import *
  2. from .torch_imports import *
  3. from sklearn.metrics import confusion_matrix
  def ceildiv(a, b):
      """Return `a / b` rounded up, computed with floor division only.

      Floor-dividing the negated numerator rounds the negated quotient down;
      negating that result therefore rounds the true quotient up.
      """
      floored_negative = -a // b
      return -floored_negative
  def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None, maintitle=None):
      """Plot a collection of images on a grid of subplots.

      Arguments:
          ims (sequence): the images to show. When the first element is a numpy
              array the whole collection is converted to one array, and a 4-D
              batch whose last axis is not 3 is transposed from
              (n, c, h, w) to (n, h, w, c).
          figsize (tuple): figure size
          rows (int): number of rows in the grid
          interp (bool): if True let matplotlib pick its default interpolation,
              otherwise disable interpolation ('none')
          titles (list): optional per-image titles, indexed like `ims`
          maintitle (string): optional title for the whole figure
      """
      # Guard the `ims[0]` peek so an empty collection yields an empty figure
      # instead of an IndexError.
      if len(ims) > 0 and isinstance(ims[0], np.ndarray):
          ims = np.array(ims)
          # Only a 4-D batch can be permuted with a 4-axis transpose; lower-rank
          # batches (e.g. a stack of 2-D grayscale images) are left as they are.
          if ims.ndim == 4 and ims.shape[-1] != 3: ims = ims.transpose((0,2,3,1))
      f = plt.figure(figsize=figsize)
      if maintitle is not None:
          plt.suptitle(maintitle, fontsize=16)
      for i, im in enumerate(ims):
          sp = f.add_subplot(rows, ceildiv(len(ims), rows), i+1)
          sp.axis('Off')
          if titles is not None: sp.set_title(titles[i], fontsize=16)
          # the subplot just added is the current axes, so this draws into `sp`
          plt.imshow(im, interpolation=None if interp else 'none')
  def plots_from_files(imspaths, figsize=(10,5), rows=1, titles=None, maintitle=None):
      """Plots images given image files.

      Arguments:
          imspaths (list): list of paths of the image files to read and show
          figsize (tuple): figure size
          rows (int): number of rows
          titles (list): list of titles, indexed like `imspaths`
          maintitle (string): main title
      """
      f = plt.figure(figsize=figsize)
      if maintitle is not None: plt.suptitle(maintitle, fontsize=16)
      for i, impath in enumerate(imspaths):
          sp = f.add_subplot(rows, ceildiv(len(imspaths), rows), i+1)
          sp.axis('Off')
          if titles is not None: sp.set_title(titles[i], fontsize=16)
          img = plt.imread(impath)
          # the subplot just added is the current axes, so this draws into `sp`
          plt.imshow(img)
  def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, figsize=None):
      """
      This function prints and plots the confusion matrix.
      Normalization can be applied by setting `normalize=True`.
      (This function is copied from the scikit docs.)

      Arguments:
          cm (numpy.ndarray): the confusion matrix; rows are labelled
              'True label' and columns 'Predicted label'
          classes (list): tick labels, one per class
          normalize (bool): if True divide every row by its sum before plotting
          title (string): title of the plot
          cmap: matplotlib colormap used for the heatmap
          figsize (tuple): figure size, passed to `plt.figure`
      """
      # Normalize BEFORE drawing, so that the heatmap colours, the colorbar,
      # the text threshold and the cell labels all describe the same values.
      if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
      print(cm)

      plt.figure(figsize=figsize)
      plt.imshow(cm, interpolation='nearest', cmap=cmap)
      plt.title(title)
      plt.colorbar()
      tick_marks = np.arange(len(classes))
      plt.xticks(tick_marks, classes, rotation=45)
      plt.yticks(tick_marks, classes)

      # Floating-point cells (e.g. after normalization) are shown with two
      # decimals; every other dtype is passed to `plt.text` unchanged.
      is_float = np.issubdtype(cm.dtype, np.floating)
      thresh = cm.max() / 2.
      for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
          cell = cm[i, j]
          label = format(cell, '.2f') if is_float else cell
          # white text on the dark (high-value) cells, black on the light ones
          plt.text(j, i, label, horizontalalignment="center", color="white" if cell > thresh else "black")

      plt.ylabel('True label')
      plt.xlabel('Predicted label')
      # run last so the layout also accounts for the axis labels
      plt.tight_layout()
  def plots_raw(ims, figsize=(12,6), rows=1, titles=None):
      """Show the images in `ims` as they are on a grid of `rows` rows.

      Arguments:
          ims (sequence): images accepted by `plt.imshow`
          figsize (tuple): figure size
          rows (int): number of rows in the grid
          titles (list): optional per-image titles, indexed like `ims`
      """
      fig = plt.figure(figsize=figsize)
      position = 0
      while position < len(ims):
          axes = fig.add_subplot(rows, ceildiv(len(ims), rows), position + 1)
          axes.axis('Off')
          if titles is not None:
              axes.set_title(titles[position], fontsize=16)
          plt.imshow(ims[position])
          position += 1
  def load_img_id(ds, idx, path):
      """Open the image file of dataset entry `idx` and return it as a numpy array.

      The file name is read from `ds.fnames[idx]` and joined onto `path`.
      """
      img_path = os.path.join(path, ds.fnames[idx])
      opened = PIL.Image.open(img_path)
      return np.array(opened)
  class ImageModelResults():
      """ Visualize the results of an image model

      Arguments:
          ds (dataset): a dataset which contains the images; this class reads
              `ds.y`, `ds[i][0]` and `ds.denorm`
          log_preds (numpy.ndarray): predictions for the dataset in log scale

      Returns:
          ImageModelResults
      """

      def __init__(self, ds, log_preds):
          """Initialize an ImageModelResults class instance"""
          self.ds = ds
          # index of the maximum prediction along axis 1, i.e. the predicted class
          # log_preds.shape = (number_of_samples, number_of_classes);
          # preds.shape = (number_of_samples,)
          self.preds = np.argmax(log_preds, axis=1)
          # undo the log to obtain the probabilities
          self.probs = np.exp(log_preds)
          # number of classes = size of axis 1
          self.num_classes = log_preds.shape[1]

      def plot_val_with_title(self, idxs, y):
          """ Displays the images and their probabilities of belonging to a certain class

          Arguments:
              idxs (numpy.ndarray): indexes of the image samples from the dataset
              y (int): the selected class

          Returns:
              False when `idxs` is empty, otherwise the result of `plots(...)`,
              which draws the images in a single row titled with their
              probability for class `y`
          """
          # nothing to display
          if len(idxs) == 0:
              return False
          imgs = np.stack([self.ds[x][0] for x in idxs])
          title_probs = [self.probs[x,y] for x in idxs]
          return plots(self.ds.denorm(imgs), rows=1, titles=title_probs)

      def most_by_mask(self, mask, y, mult, n=4):
          """ Extracts the first `n` most correct/incorrect indexes from the ordered list of probabilities

          Arguments:
              mask (numpy.ndarray): a boolean array with shape (num_of_samples,) which contains True for the samples to consider, and False everywhere else
              y (int): the selected class
              mult (int): sets the ordering; -1 descending, 1 ascending
              n (int): maximum number of indexes to return (default 4)

          Returns:
              idxs (ndarray): An array of at most `n` indexes
          """
          idxs = np.where(mask)[0]
          # slicing clamps to the available length, so fewer than n is fine
          order = np.argsort(mult * self.probs[idxs,y])[:n]
          return idxs[order]

      def most_uncertain_by_mask(self, mask, y, n=4):
          """ Extracts the first `n` most uncertain indexes from the ordered list of probabilities

          Arguments:
              mask (numpy.ndarray): a boolean array with shape (num_of_samples,) which contains True for the samples to consider, and False everywhere else
              y (int): the selected class
              n (int): maximum number of indexes to return (default 4)

          Returns:
              idxs (ndarray): An array of at most `n` indexes
          """
          idxs = np.where(mask)[0]
          # the most uncertain samples will have abs(probs-1/num_classes) close to 0;
          distance = np.abs(self.probs[idxs,y]-(1/self.num_classes))
          return idxs[np.argsort(distance)[:n]]

      def most_by_correct(self, y, is_correct, n=4):
          """ Extracts the samples whose true class is the selected class (y) and whose prediction matches the requested case (prediction is correct - is_correct=True, prediction is wrong - is_correct=False)

          Arguments:
              y (int): the selected class
              is_correct (boolean): a boolean flag (True, False) which specifies what to look for. Ex: True - most correct samples, False - most incorrect samples
              n (int): maximum number of indexes to return (default 4)

          Returns:
              idxs (numpy.ndarray): An array of indexes (numpy.ndarray)
          """
          # Most correct: sort descending (mult=-1) so the biggest probabilities for class y come first.
          # Most incorrect: sort ascending (mult=1) since the interest is in the smallest probabilities.
          mult = -1 if is_correct else 1
          prediction_matches = (self.preds == self.ds.y)
          mask = (prediction_matches == is_correct) & (self.ds.y == y)
          return self.most_by_mask(mask, y, mult, n)

      def plot_by_correct(self, y, is_correct, n=4):
          """ Plots the images which correspond to the selected class (y) and to the specific case (prediction is correct - is_correct=True, prediction is wrong - is_correct=False)

          Arguments:
              y (int): the selected class
              is_correct (boolean): a boolean flag (True, False) which specifies what to look for. Ex: True - most correct samples, False - most incorrect samples
              n (int): maximum number of images to plot (default 4)
          """
          return self.plot_val_with_title(self.most_by_correct(y, is_correct, n), y)

      def most_by_uncertain(self, y, n=4):
          """ Extracts the samples whose true class is the selected class (y) and whose probability for that class is nearest to 1/number_of_classes (eg. 0.5 for 2 classes, 0.33 for 3 classes).

          Arguments:
              y (int): the selected class
              n (int): maximum number of indexes to return (default 4)

          Returns:
              idxs (numpy.ndarray): An array of indexes (numpy.ndarray)
          """
          return self.most_uncertain_by_mask((self.ds.y == y), y, n)

      def plot_most_correct(self, y, n=4):
          """ Plots the images which correspond to the selected class (y) and are most correct.

          Arguments:
              y (int): the selected class
              n (int): maximum number of images to plot (default 4)
          """
          return self.plot_by_correct(y, True, n)

      def plot_most_incorrect(self, y, n=4):
          """ Plots the images which correspond to the selected class (y) and are most incorrect.

          Arguments:
              y (int): the selected class
              n (int): maximum number of images to plot (default 4)
          """
          return self.plot_by_correct(y, False, n)

      def plot_most_uncertain(self, y, n=4):
          """ Plots the images which correspond to the selected class (y) and are most uncertain i.e have probabilities nearest to 1/number_of_classes.

          Arguments:
              y (int): the selected class
              n (int): maximum number of images to plot (default 4)
          """
          return self.plot_val_with_title(self.most_by_uncertain(y, n), y)