# generators.py

from fastai.core import *  # star import also provides torch, nn, F, math, resnet34,
                           # set_trainable, children and split_by_idxs (fastai 0.7)
from fastai.conv_learner import model_meta, cut_model
from fasterai.modules import ConvBlock, UnetBlock, UpSampleBlock, SaveFeatures
from abc import ABC, abstractmethod

class GeneratorModule(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    def set_trainable(self, trainable: bool):
        # Delegates to fastai's module-level set_trainable (from the star
        # import above), which toggles requires_grad across the whole model.
        set_trainable(self, trainable)

    @abstractmethod
    def get_layer_groups(self, precompute: bool = False) -> []:
        pass

    def freeze_to(self, n: int):
        # Freeze every layer group, then unfreeze groups n and onwards.
        c = self.get_layer_groups()
        for l in c:
            set_trainable(l, False)
        for l in c[n:]:
            set_trainable(l, True)

    def get_device(self):
        return next(self.parameters()).device
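
# Usage sketch (illustrative, not in the original file): the concrete models
# below return three layer groups, so e.g.
#
#   netG = Unet34()
#   netG.freeze_to(2)  # both pretrained resnet groups frozen, decoder trainable
#
# leaves gradients flowing only through layer groups [2:].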

class Unet34(GeneratorModule):
    @staticmethod
    def get_pretrained_resnet_base(layers_cut: int = 0):
        # Load an ImageNet-pretrained resnet34 and drop its classifier head,
        # using fastai's model_meta for the cut point and layer-group boundary.
        f = resnet34
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__()
        assert (math.log(scale, 2)).is_integer()
        leakyReLu = False
        self_attention = True
        bn = True
        sn = True
        self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
        self.relu = nn.ReLU()
        # U-Net decoder over the resnet skip connections, with self-attention
        # on up3 and spectral norm (sn) throughout.
        self.up1 = UnetBlock(512, 256, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up2 = UnetBlock(512 * nf_factor, 128, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up3 = UnetBlock(512 * nf_factor, 64, 512 * nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
        self.up4 = UnetBlock(512 * nf_factor, 64, 256 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up5 = UpSampleBlock(256 * nf_factor, 32 * nf_factor, 2 * scale, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.out = nn.Sequential(ConvBlock(32 * nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())

    # Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        h = x.shape[2]
        w = x.shape[3]
        target_h = target.shape[2] * 2
        target_w = target.shape[3] * 2
        if h < target_h or w < target_w:
            padh = target_h - h if target_h > h else 0
            padw = target_w - w if target_w > w else 0
            return F.pad(x, (0, padw, 0, padh), "constant", 0)
        return x
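
    # Worked example (illustrative, not in the original): a 15x15 skip tensor
    # merging with an 8x8 decoder tensor has a 16x16 target, so _pad adds one
    # row and one column of zeros on the bottom/right; after the UnetBlock's
    # 2x upsample the two tensors then concatenate cleanly.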

    def forward(self, x_in: torch.Tensor):
        # Encoder: run the resnet stages, stashing the skip activations.
        x = self.rn[0](x_in)
        x = self.rn[1](x)
        x = self.rn[2](x)
        enc0 = x             # 1/2 resolution
        x = self.rn[3](x)
        x = self.rn[4](x)
        enc1 = x             # 1/4 resolution
        x = self.rn[5](x)
        enc2 = x             # 1/8 resolution
        x = self.rn[6](x)
        enc3 = x             # 1/16 resolution
        x = self.rn[7](x)    # 1/32 resolution bottleneck
        x = self.relu(x)
        # Decoder: each UnetBlock doubles resolution and merges a padded skip.
        x = self.up1(x, self._pad(enc3, x))
        x = self.up2(x, self._pad(enc2, x))
        x = self.up3(x, self._pad(enc1, x))
        x = self.up4(x, self._pad(enc0, x))
        x = self.up5(x)
        x = self.out(x)
        return x

    def get_layer_groups(self, precompute: bool = False) -> []:
        # Two groups of pretrained resnet layers (split at lr_cut) plus one
        # group of decoder modules, for discriminative learning rates.
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        return lgs + [children(self)[1:]]

    def close(self):
        # self.sfs (SaveFeatures hooks) is never populated in this version,
        # since the skips are captured directly in forward(), so guard the
        # lookup to keep close() safe to call.
        for sf in getattr(self, 'sfs', []):
            sf.remove()
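
# Usage sketch (illustrative, not in the original source; assumes fastai 0.7
# and fasterai.modules are importable and resnet34 weights are downloadable).
# With scale=1 the decoder exactly undoes the encoder's 32x downsampling, so
# the output matches the input size, with Tanh values in [-1, 1]:
#
#   netG = Unet34(nf_factor=2)
#   with torch.no_grad():
#       fake = netG(torch.randn(1, 3, 256, 256))
#   assert fake.shape == (1, 3, 256, 256)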

class Unet34_V2(GeneratorModule):
    @staticmethod
    def get_pretrained_resnet_base(layers_cut: int = 0):
        f = resnet34
        cut, lr_cut = model_meta[f]
        cut -= layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def __init__(self, nf_factor: int = 1, scale: int = 1):
        super().__init__()
        assert (math.log(scale, 2)).is_integer()
        leakyReLu = False
        self_attention = True
        bn = True
        sn = True
        self.rn, self.lr_cut = Unet34_V2.get_pretrained_resnet_base()
        self.relu = nn.ReLU()
        # Same layout as Unet34 but with a narrower decoder: up3-up5 taper to
        # 256/128/32 * nf_factor channels instead of staying wide.
        self.up1 = UnetBlock(512, 256, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up2 = UnetBlock(512 * nf_factor, 128, 512 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up3 = UnetBlock(512 * nf_factor, 64, 256 * nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn)
        self.up4 = UnetBlock(256 * nf_factor, 64, 128 * nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.up5 = UpSampleBlock(128 * nf_factor, 32 * nf_factor, 2 * scale, sn=sn, leakyReLu=leakyReLu, bn=bn)
        self.out = nn.Sequential(ConvBlock(32 * nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())

    # Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        h = x.shape[2]
        w = x.shape[3]
        target_h = target.shape[2] * 2
        target_w = target.shape[3] * 2
        if h < target_h or w < target_w:
            padh = target_h - h if target_h > h else 0
            padw = target_w - w if target_w > w else 0
            return F.pad(x, (0, padw, 0, padh), "constant", 0)
        return x

    def forward(self, x_in: torch.Tensor):
        # Identical encode/decode flow to Unet34.forward.
        x = self.rn[0](x_in)
        x = self.rn[1](x)
        x = self.rn[2](x)
        enc0 = x
        x = self.rn[3](x)
        x = self.rn[4](x)
        enc1 = x
        x = self.rn[5](x)
        enc2 = x
        x = self.rn[6](x)
        enc3 = x
        x = self.rn[7](x)
        x = self.relu(x)
        x = self.up1(x, self._pad(enc3, x))
        x = self.up2(x, self._pad(enc2, x))
        x = self.up3(x, self._pad(enc1, x))
        x = self.up4(x, self._pad(enc0, x))
        x = self.up5(x)
        x = self.out(x)
        return x

    def get_layer_groups(self, precompute: bool = False) -> []:
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        return lgs + [children(self)[1:]]

    def close(self):
        # Same guard as Unet34.close: self.sfs is never populated here.
        for sf in getattr(self, 'sfs', []):
            sf.remove()
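
if __name__ == '__main__':
    # Smoke-test sketch, not part of the original file. Assumes fastai 0.7 is
    # installed and ImageNet weights for resnet34 can be downloaded. Builds
    # the V2 generator, keeps only the decoder group trainable, and runs a
    # dummy batch through it.
    netG = Unet34_V2(nf_factor=2)
    n_groups = len(netG.get_layer_groups())  # 2 resnet groups + 1 decoder group
    netG.freeze_to(n_groups - 1)             # only the decoder stays trainable
    with torch.no_grad():
        out = netG(torch.randn(1, 3, 192, 192))
    print(out.shape)  # expected: torch.Size([1, 3, 192, 192]) when scale=1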