# generators.py

from fastai.core import *
from fastai.conv_learner import model_meta, cut_model
from fastai.transforms import scale_min
from .modules import ConvBlock, UnetBlock, UpSampleBlock, SaveFeatures
from abc import ABC, abstractmethod
from torchvision import transforms
from torch.nn.utils.spectral_norm import spectral_norm


class GeneratorModule(ABC, nn.Module):
    """Base class for generator networks: layer-group freezing and device lookup."""

    def __init__(self):
        super().__init__()

    def set_trainable(self, trainable: bool):
        set_trainable(self, trainable)

    @abstractmethod
    def get_layer_groups(self, precompute: bool = False) -> []:
        pass

    @abstractmethod
    def forward(self, x_in: torch.Tensor, max_render_sz: int = 400):
        pass

    def freeze_to(self, n: int):
        # Freeze every layer group, then unfreeze groups from index n onward.
        c = self.get_layer_groups()
        for l in c:
            set_trainable(l, False)
        for l in c[n:]:
            set_trainable(l, True)

    def get_device(self):
        return next(self.parameters()).device


class AbstractUnet(GeneratorModule):
    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__()
        assert (math.log(scale, 2)).is_integer()
        self.rn, self.lr_cut = self._get_pretrained_resnet_base()
        ups = self._get_decoding_layers(nf_factor=nf_factor, scale=scale)
        self.relu = nn.ReLU()
        self.up1 = ups[0]
        self.up2 = ups[1]
        self.up3 = ups[2]
        self.up4 = ups[3]
        self.up5 = ups[4]
        self.out = nn.Sequential(ConvBlock(32 * nf_factor, 3, ks=3, actn=False, bn=False, sn=True), nn.Tanh())

    @abstractmethod
    def _get_pretrained_resnet_base(self, layers_cut: int = 0):
        pass

    @abstractmethod
    def _get_decoding_layers(self, nf_factor: int, scale: int):
        pass

    # Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x: torch.Tensor, target: torch.Tensor, total_padh: int, total_padw: int) -> torch.Tensor:
        h = x.shape[2]
        w = x.shape[3]
        target_h = target.shape[2] * 2
        target_w = target.shape[3] * 2

        if h < target_h or w < target_w:
            padh = target_h - h if target_h > h else 0
            total_padh = total_padh + padh
            padw = target_w - w if target_w > w else 0
            total_padw = total_padw + padw
            return (F.pad(x, (0, padw, 0, padh), "reflect", 0), total_padh, total_padw)

        return (x, total_padh, total_padw)

    def _remove_padding(self, x: torch.Tensor, padh: int, padw: int) -> torch.Tensor:
        if padw == 0 and padh == 0:
            return x

        target_h = x.shape[2] - padh
        target_w = x.shape[3] - padw
        return x[:, :, :target_h, :target_w]

    def _encode(self, x: torch.Tensor):
        # Run the resnet trunk, saving intermediate activations for the skip connections.
        x = self.rn[0](x)
        x = self.rn[1](x)
        x = self.rn[2](x)
        enc0 = x
        x = self.rn[3](x)
        x = self.rn[4](x)
        enc1 = x
        x = self.rn[5](x)
        enc2 = x
        x = self.rn[6](x)
        enc3 = x
        x = self.rn[7](x)
        return (x, enc0, enc1, enc2, enc3)

    def _decode(self, x: torch.Tensor, enc0: torch.Tensor, enc1: torch.Tensor, enc2: torch.Tensor, enc3: torch.Tensor):
        padh = 0
        padw = 0
        x = self.relu(x)
        enc3, padh, padw = self._pad(enc3, x, padh, padw)
        x = self.up1(x, enc3)
        enc2, padh, padw = self._pad(enc2, x, padh, padw)
        x = self.up2(x, enc2)
        enc1, padh, padw = self._pad(enc1, x, padh, padw)
        x = self.up3(x, enc1)
        enc0, padh, padw = self._pad(enc0, x, padh, padw)
        x = self.up4(x, enc0)
        x = self.up5(x)
        x = self.out(x)
        # This is a bit too much padding being removed, but I
        # haven't yet figured out a good way to determine what
        # exactly should be removed. This is consistently more
        # than enough though.
        x = self._remove_padding(x, padh, padw)
        return x

    def forward(self, x: torch.Tensor):
        x, enc0, enc1, enc2, enc3 = self._encode(x)
        x = self._decode(x, enc0, enc1, enc2, enc3)
        return x

    def get_layer_groups(self, precompute: bool = False) -> []:
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        return lgs + [children(self)[1:]]

    def close(self):
        # Defensive default: this class does not register SaveFeatures hooks itself.
        for sf in getattr(self, 'sfs', []):
            sf.remove()


class Unet34(AbstractUnet):
    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut: int = 0):
        f = resnet34
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor: int, scale: int):
        self_attention = True
        bn = True
        sn = True
        leakyReLu = False
        layers = []
        layers.append(UnetBlock(512, 256, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 128, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 64, 512 * nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 64, 256 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256 * nf_factor, 32 * nf_factor, 2 * scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers
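

# The decoder widths above are meant to line up with a torchvision ResNet-34 trunk,
# whose activations saved in _encode carry 64 (stem relu), 64 (layer1), 128 (layer2),
# 256 (layer3) and 512 (layer4) channels. Assuming UnetBlock's second argument is the
# channel count of the skip activation it concatenates with, up1..up4 consume
# enc3 (256), enc2 (128), enc1 (64) and enc0 (64) respectively, and up5 only upsamples.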


class Unet101(AbstractUnet):
    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut: int = 0):
        f = resnet101
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor: int, scale: int):
        self_attention = True
        bn = True
        sn = True
        leakyReLu = False
        layers = []
        layers.append(UnetBlock(2048, 1024, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 512, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 256, 512 * nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 64, 256 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256 * nf_factor, 32 * nf_factor, 2 * scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers


class Unet152(AbstractUnet):
    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut: int = 0):
        f = resnet152
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor: int, scale: int):
        self_attention = True
        bn = True
        sn = True
        leakyReLu = False
        layers = []
        layers.append(UnetBlock(2048, 1024, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 512, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 256, 512 * nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512 * nf_factor, 64, 256 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256 * nf_factor, 32 * nf_factor, 2 * scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers
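

# A minimal smoke-test sketch, assuming a working fastai 0.7 environment where the
# pretrained ResNet-34 weights can be downloaded and the ConvBlock/UnetBlock/UpSampleBlock
# helpers in .modules upsample by 2 and otherwise preserve spatial size. Because of the
# relative import above, run it as a module, e.g. python -m <package>.generators.
if __name__ == '__main__':
    net = Unet34(nf_factor=1, scale=1)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 224, 224))
    # Under those assumptions the output is a 3-channel, Tanh-bounded image at
    # scale x the input resolution, i.e. [1, 3, 224, 224] here.
    print(out.shape)
    # The encoder is split at lr_cut and the decoder forms the last group, so
    # freeze_to(n) leaves only groups n onward trainable for staged fine-tuning.
    print(len(net.get_layer_groups()))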