critics.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from fastai.core import *
  2. from fastai.torch_core import *
  3. from fastai.vision import *
  4. from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
  5. _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
  6. def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
  7. return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
  8. def custom_gan_critic(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
  9. "Critic to train a `GAN`."
  10. layers = [
  11. _conv(n_channels, nf, ks=4, stride=2),
  12. nn.Dropout2d(p/2)]
  13. for i in range(n_blocks):
  14. layers += [
  15. _conv(nf, nf, ks=3, stride=1),
  16. nn.Dropout2d(p),
  17. _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
  18. nf *= 2
  19. layers += [
  20. _conv(nf, nf, ks=3, stride=1),
  21. _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
  22. Flatten()]
  23. return nn.Sequential(*layers)
  24. def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
  25. return Learner(data, custom_gan_critic(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
  26. def custom_gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
  27. "Critic to train a `GAN`."
  28. layers = [
  29. _conv(n_channels, nf, ks=4, stride=2),
  30. nn.Dropout2d(p/2),
  31. _conv(nf, nf, ks=3, stride=1)]
  32. for i in range(n_blocks):
  33. layers += [
  34. nn.Dropout2d(p),
  35. _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
  36. nf *= 2
  37. layers += [
  38. _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
  39. Flatten()]
  40. return nn.Sequential(*layers)
  41. def colorize_crit_learner2(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
  42. return Learner(data, custom_gan_critic2(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)