utils.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import math, os, json, sys, re, numpy as np, pickle, PIL, scipy
  2. from PIL import Image
  3. from glob import glob
  4. from matplotlib import pyplot as plt
  5. from operator import itemgetter, attrgetter, methodcaller
  6. from collections import OrderedDict
  7. import itertools
  8. from itertools import chain
  9. import pandas as pd
  10. from numpy.random import random, permutation, randn, normal, uniform, choice
  11. from numpy import newaxis
  12. from scipy import misc, ndimage
  13. from scipy.ndimage.interpolation import zoom
  14. from scipy.ndimage import imread
  15. from sklearn.metrics import confusion_matrix
  16. from sklearn.preprocessing import OneHotEncoder
  17. from sklearn.manifold import TSNE
  18. import bcolz
  19. from IPython.lib.display import FileLink
  20. import keras
  21. from keras import backend as K
  22. from keras.utils.data_utils import get_file
  23. from keras.utils import np_utils
  24. from keras.utils.np_utils import to_categorical
  25. from keras.models import Sequential, Model
  26. from keras.layers import Input, Embedding, Reshape, merge, LSTM, Bidirectional
  27. from keras.layers import TimeDistributed, Activation, SimpleRNN, GRU
  28. from keras.layers import Flatten, Dense, Dropout, Lambda
  29. from keras.regularizers import l2, l1
  30. from keras.layers.normalization import BatchNormalization
  31. from keras.optimizers import SGD, RMSprop, Adam
  32. from keras.layers import deserialize as layer_from_config
  33. from keras.metrics import categorical_crossentropy, categorical_accuracy
  34. from keras.layers.convolutional import *
  35. from keras.preprocessing import image, sequence
  36. from keras.preprocessing.text import Tokenizer
  37. from vgg16 import Vgg16
  38. np.set_printoptions(precision=4, linewidth=100)
  39. to_bw = np.array([0.299, 0.587, 0.114])
  40. def gray(img): return np.rollaxis(img, 0, 1).dot(to_bw)
  41. def to_plot(img): return np.rollaxis(img, 0, 1).astype(np.uint8)
  42. def plot(img): plt.imshow(to_plot(img))
  43. def floor(x): return int(math.floor(x))
  44. def ceil(x): return int(math.ceil(x))
  45. def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
  46. if type(ims[0]) is np.ndarray:
  47. ims = np.array(ims).astype(np.uint8)
  48. if (ims.shape[-1] != 3): ims = ims.transpose((0,2,3,1))
  49. f = plt.figure(figsize=figsize)
  50. for i in range(len(ims)):
  51. sp = f.add_subplot(rows, len(ims)//rows, i+1)
  52. sp.axis('Off')
  53. if titles is not None: sp.set_title(titles[i], fontsize=16)
  54. plt.imshow(ims[i], interpolation=None if interp else 'none')
  55. def do_clip(arr, mx):
  56. clipped = np.clip(arr, (1-mx)/1, mx)
  57. return clipped/clipped.sum(axis=1)[:, np.newaxis]
  58. def wrap_config(layer):
  59. return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}
  60. def copy_layer(layer): return layer_from_config(wrap_config(layer))
  61. def copy_layers(layers): return [copy_layer(layer) for layer in layers]
  62. def copy_weights(from_layers, to_layers):
  63. for from_layer,to_layer in zip(from_layers, to_layers):
  64. to_layer.set_weights(from_layer.get_weights())
  65. def save_array(fname, arr):
  66. c=bcolz.carray(arr, rootdir=fname, mode='w')
  67. c.flush()
  68. def load_array(fname): return bcolz.open(fname)[:]
  69. def get_classes(path):
  70. batches = get_batches(path+'train', shuffle=False, batch_size=1)
  71. val_batches = get_batches(path+'valid', shuffle=False, batch_size=1)
  72. test_batches = get_batches(path+'test', shuffle=False, batch_size=1)
  73. return (val_batches.classes, batches.classes, onehot(val_batches.classes), onehot(batches.classes),
  74. val_batches.filenames, batches.filenames, test_batches.filenames)
  75. def limit_mem():
  76. K.get_session().close()
  77. cfg = K.tf.ConfigProto()
  78. cfg.gpu_options.allow_growth = True
  79. K.set_session(K.tf.Session(config=cfg))
  80. class MixIterator(object):
  81. def __init__(self, iters):
  82. self.iters = iters
  83. self.multi = type(iters) is list
  84. if self.multi:
  85. self.N = sum([it[0].N for it in self.iters])
  86. else:
  87. self.N = sum([it.N for it in self.iters])
  88. def reset(self):
  89. for it in self.iters: it.reset()
  90. def __iter__(self):
  91. return self
  92. def next(self, *args, **kwargs):
  93. if self.multi:
  94. nexts = [[next(it) for it in o] for o in self.iters]
  95. n0 = np.concatenate([n[0] for n in nexts])
  96. n1 = np.concatenate([n[1] for n in nexts])
  97. return (n0, n1)
  98. else:
  99. nexts = [next(it) for it in self.iters]
  100. n0 = np.concatenate([n[0] for n in nexts])
  101. n1 = np.concatenate([n[1] for n in nexts])
  102. return (n0, n1)