training.py 12 KB


  1. from fastai.core import *
  2. from fastai.torch_imports import *
  3. from fastai.dataset import Transform
  4. from fastai.layer_optimizer import LayerOptimizer
  5. from fastai.sgdr import CircularLR_beta
  6. from fasterai.modules import ConvBlock
  7. from fasterai.generators import GeneratorModule
  8. from fasterai.dataset import ImageGenDataLoader
  9. from collections import Iterable
  10. import torch.utils.hooks as hooks
  11. from abc import ABC
  class CriticModule(ABC, nn.Module):
      """Abstract base class for GAN critics.

      Concrete critics implement ``get_layer_groups``; the helpers here freeze or
      unfreeze those groups and report which device the parameters live on.
      """

      def __init__(self):
          super().__init__()

      def freeze_to(self, n:int):
          """Make only layer groups ``n`` onward trainable.

          Every group is frozen first, then the slice ``groups[n:]`` is unfrozen, so
          ordinary slice semantics apply to ``n`` (negative values count from the end).
          """
          groups = self.get_layer_groups()
          for group in groups:
              set_trainable(group, False)
          for group in groups[n:]:
              set_trainable(group, True)

      def set_trainable(self, trainable:bool):
          """Set the trainability of this whole module."""
          # Inside a method body the bare name resolves to the module-level
          # ``set_trainable`` helper (presumably from the fastai star import),
          # not to this method, so this is not recursive.
          set_trainable(self, trainable)

      @abstractmethod
      def get_layer_groups(self)->[]:
          """Return the ordered layer groups used by ``freeze_to``."""

      def get_device(self):
          """Return the device of this module's first parameter.

          Raises ``StopIteration`` if the module has no parameters.
          """
          first_param = next(self.parameters())
          return first_param.device
  class DCCritic(CriticModule):
      """DCGAN-style critic built from fasterai ``ConvBlock``s.

      Layout: ``initial`` (stride-2 conv, dropout, stride-1 conv) -> ``mid``
      (``log2(scale) - 1`` reduce stages, each doubling the channel count) ->
      ``out`` (a single conv down to one channel). ``forward`` returns the raw
      scores together with the ``mid`` feature map.
      """

      def _generate_reduce_layers(self, nf:int, self_attention:bool=True, sn:bool=True):
          """Return one reduce stage: dropout, then a stride-2 ConvBlock mapping ``nf`` -> ``nf*2`` channels.

          Args:
              nf: input channel count of this stage.
              self_attention: forwarded to ``ConvBlock``. Defaults to True, which is
                  what the original always used.
              sn: forwarded to ``ConvBlock`` as ``sn``. Defaults to True, as before.
          """
          return [
              nn.Dropout2d(0.5),
              ConvBlock(nf, nf*2, 4, 2, bn=False, sn=sn, leakyReLu=True, self_attention=self_attention),
          ]

      def __init__(self, ni:int=3, nf:int=128, scale:int=16, sn:bool=True, self_attention:bool=True):
          """
          Args:
              ni: number of input channels.
              nf: channels after the initial block; each reduce stage doubles it.
              scale: product of all stride-2 steps. Must be a power of two >= 2:
                  one factor of 2 comes from ``initial``, the rest from ``mid``.
              sn: passed as ``sn`` to every ConvBlock.
              self_attention: when True, only the FIRST reduce stage gets self-attention.

          Raises:
              ValueError: if ``scale`` is not a power of two >= 2.
          """
          super().__init__()
          # Exact integer power-of-two test; math.log(x, 2) can be inexact for large x.
          if scale < 2 or scale & (scale - 1):
              raise ValueError('scale must be a power of two >= 2, got {}'.format(scale))
          n_reduce = scale.bit_length() - 1  # == log2(scale)
          self.initial = nn.Sequential(
              ConvBlock(ni, nf, 4, 2, bn=False, sn=sn, leakyReLu=True),
              nn.Dropout2d(0.2),
              ConvBlock(nf, nf, 3, 1, bn=False, sn=sn, leakyReLu=True)
          )
          cndf = nf
          mid_layers = []
          # NOTE: the original computed ``use_attention`` but never passed it on, so
          # every stage got self-attention. Honouring it changes the parameter set, so
          # critic checkpoints saved by the old code will presumably not load strictly.
          for i in range(n_reduce - 1):
              use_attention = (i == 0 and self_attention)
              mid_layers.extend(self._generate_reduce_layers(nf=cndf, self_attention=use_attention, sn=sn))
              cndf *= 2
          self.mid = nn.Sequential(*mid_layers)
          # Kept as a one-element Sequential so state_dict keys stay ``out.0.*``.
          self.out = nn.Sequential(
              ConvBlock(cndf, 1, ks=4, stride=1, bias=False, bn=False, sn=sn, pad=0, actn=False))

      def get_layer_groups(self)->[]:
          """Return the layer groups via fastai's ``children`` helper (presumably initial, mid, out)."""
          return children(self)

      def forward(self, input):
          """Return ``(scores, features)``: the ``out`` output and the ``mid`` feature map."""
          x = self.initial(input)
          x = self.mid(x)
          return self.out(x), x
  class GenResult():
      """Metrics from one generator step: adversarial cost, iteration number, extra loss."""

      def __init__(self, gcost: np.array, iters: int, gaddlloss: np.array):
          self.gcost, self.iters, self.gaddlloss = gcost, iters, gaddlloss
  class CriticResult():
      """Metrics from one critic step: total hinge loss, its real/fake terms, and the cost."""

      def __init__(self, hingeloss: np.array, dreal: np.array, dfake: np.array, dcost: np.array):
          self.hingeloss, self.dreal, self.dfake, self.dcost = hingeloss, dreal, dfake, dcost
  class GANTrainSchedule():
      """Settings for one stage of GAN training.

      A stage has an image size, batch size, critic/generator learning rates,
      checkpoint paths and a generator freeze point. It also owns the
      ``ImageGenDataLoader`` for that stage.
      """

      @staticmethod
      def generate_schedules(szs:[int], bss:[int], path:Path, keep_pcts:[float], save_base_name:str,
              c_lrs:[float], g_lrs:[float], gen_freeze_tos:[int], lrs_unfreeze_factor:float=0.1,
              x_noise:int=None, random_seed=None, x_tfms:[Transform]=None, extra_aug_tfms:[Transform]=None,
              reduce_x_scale=1):
          """Build one schedule per image size in ``szs``.

          ``bss``, ``keep_pcts`` and ``gen_freeze_tos`` are read in parallel with ``szs``
          (extra trailing entries are ignored). ``c_lrs``/``g_lrs`` are shared by all
          stages and are multiplied by ``lrs_unfreeze_factor`` for stages whose
          ``gen_freeze_to`` is 0. Checkpoints are written to
          ``path.parent/'<save_base_name>_critic_<sz>.h5'`` and ``..._gen_<sz>.h5``.

          Raises:
              ValueError: if ``bss``, ``keep_pcts`` or ``gen_freeze_tos`` is shorter than ``szs``.
          """
          for name, values in (('bss', bss), ('keep_pcts', keep_pcts), ('gen_freeze_tos', gen_freeze_tos)):
              if len(values) < len(szs):
                  raise ValueError('{} has {} entries but szs has {}'.format(name, len(values), len(szs)))
          # Convert once so plain Python lists scale element-wise; ``[1e-4] * 0.1``
          # would raise TypeError (a list cannot be multiplied by a float).
          c_lrs_arr = np.array(c_lrs)
          g_lrs_arr = np.array(g_lrs)
          scheds = []
          for sz, bs, keep_pct, gen_freeze_to in zip(szs, bss, keep_pcts, gen_freeze_tos):
              lr_factor = lrs_unfreeze_factor if gen_freeze_to == 0 else 1.0
              scheds.append(GANTrainSchedule(
                  sz=sz, bs=bs, path=path,
                  critic_lrs=c_lrs_arr * lr_factor, gen_lrs=g_lrs_arr * lr_factor,
                  critic_save_path=path.parent/'{}_critic_{}.h5'.format(save_base_name, sz),
                  gen_save_path=path.parent/'{}_gen_{}.h5'.format(save_base_name, sz),
                  random_seed=random_seed, x_noise=x_noise, keep_pct=keep_pct,
                  x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,
                  reduce_x_scale=reduce_x_scale, gen_freeze_to=gen_freeze_to))
          return scheds

      def __init__(self, sz:int, bs:int, path:Path, critic_lrs:[float], gen_lrs:[float],
              critic_save_path: Path, gen_save_path: Path, random_seed=None, x_noise:int=None,
              keep_pct:float=1.0, num_epochs=1, x_tfms:[Transform]=None, extra_aug_tfms:[Transform]=None,
              reduce_x_scale=1, gen_freeze_to=0):
          """Store stage settings and build the stage's ``ImageGenDataLoader``.

          ``critic_lrs``/``gen_lrs`` are stored as numpy arrays. ``x_tfms`` and
          ``extra_aug_tfms`` default to fresh empty lists.
          """
          # Fresh lists per call: a shared ``[]`` default could be mutated across instances.
          x_tfms = [] if x_tfms is None else x_tfms
          extra_aug_tfms = [] if extra_aug_tfms is None else extra_aug_tfms
          # Not assigned anywhere else in this class; kept for callers that may read it.
          self.md = None
          self.data_loader = ImageGenDataLoader(sz=sz, bs=bs, path=path, random_seed=random_seed, x_noise=x_noise,
                  keep_pct=keep_pct, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, reduce_x_scale=reduce_x_scale)
          self.sz = sz
          self.bs = bs
          self.path = path
          self.critic_lrs = np.array(critic_lrs)
          self.gen_lrs = np.array(gen_lrs)
          self.critic_save_path = critic_save_path
          self.gen_save_path = gen_save_path
          self.num_epochs = num_epochs
          self.gen_freeze_to = gen_freeze_to

      def get_model_data(self):
          """Return the model data produced by this stage's ``ImageGenDataLoader``.

          The original code tags this as lazy initialisation; presumably the loader
          builds the data on request. Confirm in ``ImageGenDataLoader``.
          """
          return self.data_loader.get_model_data()
  class GANTrainer():
      """Alternating critic/generator trainer using a hinge loss.

      Each loop iteration takes one batch for a critic update and the next batch for
      a generator update. Learning rates are driven by fastai ``CircularLR_beta``
      schedules, and both networks are saved every ``save_iters`` iterations.
      """

      def __init__(self, netD: nn.Module, netG: GeneratorModule, save_iters:int=1000, genloss_fns:[]=None):
          """
          Args:
              netD: critic. Called as ``netD(images)`` and must return a pair whose
                  first element is the score tensor.
              netG: generator. Called as ``netG(images)``.
              save_iters: save both networks whenever ``iters`` is a multiple of this.
                  0 (or None) disables saving.
              genloss_fns: extra generator losses, each called as ``fn(fake, real)`` and
                  summed onto the adversarial cost. None means no extra losses.
          """
          self.netD = netD
          self.netG = netG
          self._train_loop_hooks = OrderedDict()
          self._train_begin_hooks = OrderedDict()
          # Fresh list per trainer instead of a shared mutable default.
          self.genloss_fns = [] if genloss_fns is None else genloss_fns
          self.save_iters = save_iters
          self.iters = 0

      def register_train_loop_hook(self, hook):
          """Register ``hook(gresult, cresult)`` to run after every completed iteration.

          Returns a ``RemovableHandle`` that unregisters the hook. The hook must
          return None, otherwise a RuntimeError is raised when it is called.
          """
          handle = hooks.RemovableHandle(self._train_loop_hooks)
          self._train_loop_hooks[handle.id] = hook
          return handle

      def register_train_begin_hook(self, hook):
          """Register ``hook()`` to run once, when the LR schedules are first created.

          Returns a ``RemovableHandle``. The hook must return None.
          """
          handle = hooks.RemovableHandle(self._train_begin_hooks)
          self._train_begin_hooks[handle.id] = hook
          return handle

      def train(self, scheds:[GANTrainSchedule]):
          """Run every schedule in order, for ``sched.num_epochs`` epochs each."""
          for sched in scheds:
              self.md = sched.get_model_data()
              self.dpath = sched.critic_save_path
              self.gpath = sched.gen_save_path
              lrs_gen = sched.gen_lrs
              lrs_critic = sched.critic_lrs
              if self.iters == 0:
                  # Nothing trained yet: build optimisers and LR schedules. Their cycle
                  # length is taken from this schedule's training loader.
                  self.gen_sched = self._generate_clr_sched(self.netG, use_clr_beta=(1,8), lrs=lrs_gen, cycle_len=1)
                  self.critic_sched = self._generate_clr_sched(self.netD, use_clr_beta=(1,8), lrs=lrs_critic, cycle_len=1)
                  self._call_train_begin_hooks()
              else:
                  # Later schedules reuse the existing optimisers; only base LRs change.
                  self.gen_sched.init_lrs = lrs_gen
                  self.critic_sched.init_lrs = lrs_critic
              self._get_inner_module(self.netG).freeze_to(sched.gen_freeze_to)
              self.critic_sched.on_train_begin()
              self.gen_sched.on_train_begin()
              for _ in trange(sched.num_epochs):
                  self._train_one_epoch()

      def _get_inner_module(self, model:nn.Module):
          """Unwrap ``nn.DataParallel`` so model-specific methods can be called."""
          return model.module if isinstance(model, nn.DataParallel) else model

      def _generate_clr_sched(self, model:nn.Module, use_clr_beta: (int), lrs: [float], cycle_len: int):
          """Create an Adam LayerOptimizer for ``model`` wrapped in ``CircularLR_beta``.

          ``use_clr_beta`` is ``(div, pct, *momentums)``. Momentums are only passed on
          when more than three values are given. One cycle spans
          ``len(self.md.trn_dl) * cycle_len`` batches, so ``self.md`` must be set.
          """
          wds = 1e-7
          opt_fn = partial(optim.Adam, betas=(0.0,0.9))
          layer_opt = LayerOptimizer(opt_fn, self._get_inner_module(model).get_layer_groups(), lrs, wds)
          div, pct = use_clr_beta[:2]
          moms = use_clr_beta[2:] if len(use_clr_beta) > 3 else None
          return CircularLR_beta(layer_opt, len(self.md.trn_dl)*cycle_len, on_cycle_end=None, div=div, pct=pct, momentums=moms)

      def _train_one_epoch(self)->None:
          """Alternate critic and generator steps until the training loader runs out."""
          self.netD.train()
          self.netG.train()
          data_iter = iter(self.md.trn_dl)
          with tqdm(total=len(self.md.trn_dl)) as pbar:
              while True:
                  cresult = self._train_critic(data_iter, pbar)
                  if cresult is None:
                      break
                  # Count an iteration only once a batch was actually consumed. The
                  # original incremented before this check, adding a phantom iteration
                  # each epoch (which could also skip a scheduled save).
                  self.iters += 1
                  gresult = self._train_generator(data_iter, pbar, cresult)
                  if gresult is None:
                      break
                  self._save_if_applicable()
                  self._call_train_loop_hooks(gresult, cresult)

      def _call_train_begin_hooks(self):
          """Invoke every train-begin hook; raise RuntimeError if one returns a value."""
          for hook in self._train_begin_hooks.values():
              hook_result = hook()
              if hook_result is not None:
                  raise RuntimeError(
                      "train begin hooks should never return any values, but '{}' "
                      "didn't return None".format(hook))

      def _call_train_loop_hooks(self, gresult: GenResult, cresult: CriticResult):
          """Invoke every train-loop hook; raise RuntimeError if one returns a value."""
          for hook in self._train_loop_hooks.values():
              hook_result = hook(gresult, cresult)
              if hook_result is not None:
                  raise RuntimeError(
                      "train loop hooks should never return any values, but '{}' "
                      "didn't return None".format(hook))

      def _get_next_training_images(self, data_iter: Iterable)->(torch.Tensor,torch.Tensor):
          """Return the next ``(x, y)`` batch wrapped with fastai ``V``, or ``(None, None)`` when exhausted."""
          x, y = next(data_iter, (None, None))
          if x is None or y is None:
              return (None,None)
          orig_image = V(x)
          real_image = V(y)
          return (orig_image, real_image)

      def _train_critic(self, data_iter: Iterable, pbar: tqdm)->CriticResult:
          """Run one critic step on the next batch; return None if the loader is exhausted."""
          self._get_inner_module(self.netD).set_trainable(True)
          self._get_inner_module(self.netG).set_trainable(False)
          orig_image, real_image = self._get_next_training_images(data_iter)
          if orig_image is None:
              return None
          cresult = self._train_critic_once(orig_image, real_image)
          pbar.update()
          return cresult

      def _train_critic_once(self, orig_image: torch.Tensor, real_image: torch.Tensor)->CriticResult:
          """One hinge-loss critic update; returns the step's metrics as numpy values."""
          fake_image = self.netG(orig_image)
          dfake_raw, _ = self.netD(fake_image)
          # Hinge terms: penalise fake scores above -1 and real scores below +1.
          dfake = torch.nn.ReLU()(1.0+dfake_raw).mean()
          dreal_raw, _ = self.netD(real_image)
          dreal = torch.nn.ReLU()(1.0-dreal_raw).mean()
          self.netD.zero_grad()
          hingeloss = dfake + dreal
          hingeloss.backward()
          self.critic_sched.layer_opt.opt.step()
          # Both LR schedules are advanced on every consumed batch, critic batches included.
          self.critic_sched.on_batch_end(to_np(hingeloss))
          self.gen_sched.on_batch_end(to_np(-dfake))
          return CriticResult(to_np(hingeloss), to_np(dreal), to_np(dfake), to_np(hingeloss))

      def _train_generator(self, data_iter: Iterable, pbar: tqdm, cresult: CriticResult)->GenResult:
          """Run one generator step on the next batch; return None if the loader is exhausted."""
          orig_image, real_image = self._get_next_training_images(data_iter)
          if orig_image is None:
              return None
          gresult = self._train_generator_once(orig_image, real_image, cresult)
          pbar.update()
          return gresult

      def _train_generator_once(self, orig_image: torch.Tensor, real_image: torch.Tensor,
              cresult: CriticResult)->GenResult:
          """One generator update: minimise ``-mean(D(G(x)))`` plus any extra losses."""
          self._get_inner_module(self.netD).set_trainable(False)
          self._get_inner_module(self.netG).set_trainable(True)
          self.netG.zero_grad()
          fake_image = self.netG(orig_image)
          gcost = -self._get_dscore(fake_image)
          gaddlloss = self._calc_addl_gen_loss(real_image, fake_image)
          total_loss = gcost if gaddlloss is None else gcost + gaddlloss
          total_loss.backward()
          self.gen_sched.layer_opt.opt.step()
          self.critic_sched.on_batch_end(to_np(cresult.dcost))
          self.gen_sched.on_batch_end(to_np(gcost))
          return GenResult(to_np(gcost), self.iters, to_np(gaddlloss))

      def _save_if_applicable(self):
          """Save both networks when ``iters`` is a multiple of ``save_iters`` (0/None disables)."""
          if self.save_iters and self.iters % self.save_iters == 0:
              save_model(self.netD, self.dpath)
              save_model(self.netG, self.gpath)

      def _get_dscore(self, new_image: torch.Tensor):
          """Return the critic's mean score for ``new_image``."""
          scores, _ = self.netD(new_image)
          return scores.mean()

      def _calc_addl_gen_loss(self, real_data: torch.Tensor, fake_data: torch.Tensor)->torch.Tensor:
          """Sum ``fn(fake, real)`` over ``genloss_fns``, starting from ``V(0.0)``."""
          total_loss = V(0.0)
          for loss_fn in self.genloss_fns:
              loss = loss_fn(fake_data, real_data)
              total_loss = total_loss + loss
          return total_loss
  249. loss = loss_fn(fake_data, real_data)
  250. total_loss = total_loss + loss
  251. return total_loss