core.py

from .imports import *
from .torch_imports import *

def sum_geom(a, r, n): return a*n if r==1 else math.ceil(a*(1-r**n)/(1-r))

def is_listy(x): return isinstance(x, (list,tuple))
def is_iter(x): return isinstance(x, collections.Iterable)
def map_over(x, f): return [f(o) for o in x] if is_listy(x) else f(x)
def map_none(x, f): return None if x is None else f(x)
def delistify(x): return x[0] if is_listy(x) else x

def listify(x, y):
    if not is_iter(x): x=[x]
    n = y if type(y)==int else len(y)
    if len(x)==1: x = x * n
    return x
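# Example (illustrative): listify broadcasts a scalar or single-element list to the
# length of y, where y is an int or a sized iterable, e.g.
#   listify(1, 3)       -> [1, 1, 1]
#   listify([0], 'abc') -> [0, 0, 0]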
conv_dict = {np.dtype('int8'): torch.LongTensor, np.dtype('int16'): torch.LongTensor,
    np.dtype('int32'): torch.LongTensor, np.dtype('int64'): torch.LongTensor,
    np.dtype('float32'): torch.FloatTensor, np.dtype('float64'): torch.FloatTensor}

def A(*a):
    """Convert iterable object(s) into numpy array(s)."""
    return np.array(a[0]) if len(a)==1 else [np.array(o) for o in a]

def T(a, half=False, cuda=True):
    """
    Convert a numpy array into a pytorch tensor.
    If CUDA is available and USE_GPU is True, the resulting tensor is stored on the GPU.
    """
    if not torch.is_tensor(a):
        a = np.array(np.ascontiguousarray(a))
        if a.dtype in (np.int8, np.int16, np.int32, np.int64):
            a = torch.LongTensor(a.astype(np.int64))
        elif a.dtype in (np.float32, np.float64):
            a = torch.cuda.HalfTensor(a) if half else torch.FloatTensor(a)
        else: raise NotImplementedError(a.dtype)
    if cuda: a = to_gpu(a, non_blocking=True)
    return a
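# Example (illustrative): integer arrays become LongTensors, float arrays FloatTensors
# (or HalfTensors when half=True), moved to the GPU when one is available.
#   t = T(np.array([1.0, 2.0, 3.0], dtype=np.float32))   # FloatTensor, on GPU if available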
def create_variable(x, volatile, requires_grad=False):
    if type(x) != Variable:
        if IS_TORCH_04: x = Variable(T(x), requires_grad=requires_grad)
        else:           x = Variable(T(x), requires_grad=requires_grad, volatile=volatile)
    return x

def V_(x, requires_grad=False, volatile=False):
    '''Equivalent to create_variable: wraps x in a pytorch Variable/tensor.'''
    return create_variable(x, volatile=volatile, requires_grad=requires_grad)
def V(x, requires_grad=False, volatile=False):
    '''Creates a single pytorch Variable/tensor or a list of them, depending on the input x.'''
    return map_over(x, lambda o: V_(o, requires_grad, volatile))

def VV_(x):
    '''Creates a volatile Variable/tensor, i.e. one that does not require gradients.'''
    return create_variable(x, True)
def VV(x):
    '''Creates a single volatile Variable/tensor or a list of them, depending on the input x.'''
    return map_over(x, VV_)
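# Example (illustrative): V is typically used for inputs that take part in backprop,
# VV for inference-only inputs, e.g.
#   xb = V(np.zeros((4, 3), dtype=np.float32))   # variable(s) that may require gradients
#   yb = VV(np.zeros(4, dtype=np.int64))         # volatile / no-grad variable(s)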
def to_np(v):
    '''Returns an np.array object given an input of np.array, list, tuple, torch variable or tensor.'''
    if isinstance(v, (np.ndarray, np.generic)): return v
    if isinstance(v, (list,tuple)): return [to_np(o) for o in v]
    if isinstance(v, Variable): v=v.data
    if isinstance(v, torch.cuda.HalfTensor): v=v.float()
    return v.cpu().numpy()

IS_TORCH_04 = LooseVersion(torch.__version__) >= LooseVersion('0.4')
USE_GPU = torch.cuda.is_available()
def to_gpu(x, *args, **kwargs):
    '''Puts a pytorch variable on the GPU, if CUDA is available and USE_GPU is set to True.'''
    return x.cuda(*args, **kwargs) if USE_GPU else x

def noop(*args, **kwargs): return

def split_by_idxs(seq, idxs):
    '''A generator that yields pieces of seq, split at the indexes specified in idxs.'''
    last = 0
    for idx in idxs:
        if not (-len(seq) <= idx < len(seq)):
            raise KeyError(f'Idx {idx} is out-of-bounds')
        yield seq[last:idx]
        last = idx
    yield seq[last:]
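# Example (illustrative):
#   list(split_by_idxs([0, 1, 2, 3, 4], [2, 4]))   -> [[0, 1], [2, 3], [4]]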
def trainable_params_(m):
    '''Returns a list of trainable parameters in the model m, i.e. those that require gradients.'''
    return [p for p in m.parameters() if p.requires_grad]

def chain_params(p):
    if is_listy(p):
        return list(chain(*[trainable_params_(o) for o in p]))
    return trainable_params_(p)

def set_trainable_attr(m,b):
    m.trainable=b
    for p in m.parameters(): p.requires_grad=b

def apply_leaf(m, f):
    c = children(m)
    if isinstance(m, nn.Module): f(m)
    if len(c)>0:
        for l in c: apply_leaf(l,f)

def set_trainable(l, b):
    apply_leaf(l, lambda m: set_trainable_attr(m,b))

def SGD_Momentum(momentum):
    return lambda *args, **kwargs: optim.SGD(*args, momentum=momentum, **kwargs)

def one_hot(a,c): return np.eye(c)[a]

def partition(a, sz):
    """Splits the iterable a into chunks of size sz (the last chunk may be shorter)."""
    return [a[i:i+sz] for i in range(0, len(a), sz)]
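# Examples (illustrative):
#   one_hot(2, 4)                 -> array([0., 0., 1., 0.])
#   one_hot([0, 2], 3)            -> array([[1., 0., 0.], [0., 0., 1.]])
#   partition([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]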
def partition_by_cores(a):
    return partition(a, len(a)//num_cpus() + 1)

def num_cpus():
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return os.cpu_count()

class BasicModel():
    def __init__(self, model, name='unnamed'): self.model,self.name = model,name
    def get_layer_groups(self, do_fc=False): return children(self.model)

class SingleModel(BasicModel):
    def get_layer_groups(self): return [self.model]

class SimpleNet(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)])

    def forward(self, x):
        x = x.view(x.size(0), -1)
        for l in self.layers:
            l_x = l(x)
            x = F.relu(l_x)
        return F.log_softmax(l_x, dim=-1)
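# Example (illustrative): SimpleNet([784, 100, 10]) builds two Linear layers
# (784->100 and 100->10), flattens the input, applies ReLU between layers, and
# returns log-probabilities over the 10 output classes.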
def save(fn, a):
    """Utility function that saves a model, function, etc. as a pickle."""
    pickle.dump(a, open(fn,'wb'))
def load(fn):
    """Utility function that loads a model, function, etc. from a pickle."""
    return pickle.load(open(fn,'rb'))
def load2(fn):
    """Utility function allowing model pickling across Python 2 and Python 3."""
    return pickle.load(open(fn,'rb'), encoding='iso-8859-1')

def load_array(fname):
    '''
    Load an array using bcolz, which is based on numpy, for fast array saving and loading operations.
    https://github.com/Blosc/bcolz
    '''
    return bcolz.open(fname)[:]

def chunk_iter(iterable, chunk_size):
    '''A generator that yields chunks of iterable, chunk_size at a time.'''
    while True:
        chunk = []
        try:
            for _ in range(chunk_size): chunk.append(next(iterable))
            yield chunk
        except StopIteration:
            if chunk: yield chunk
            break
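# Example (illustrative): chunk_iter expects an iterator (not a bare list), e.g.
#   list(chunk_iter(iter(range(5)), 2))   -> [[0, 1], [2, 3], [4]]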
def set_grad_enabled(mode): return torch.set_grad_enabled(mode) if IS_TORCH_04 else contextlib.suppress()
def no_grad_context(): return torch.no_grad() if IS_TORCH_04 else contextlib.suppress()