# critics.py (1.2 KB)
  1. from fastai.core import *
  2. from fastai.torch_core import *
  3. from fastai.vision import *
  4. from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
  5. _conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)
  6. def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
  7. return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
  8. def custom_gan_critic(
  9. n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
  10. ):
  11. "Critic to train a `GAN`."
  12. layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
  13. for i in range(n_blocks):
  14. layers += [
  15. _conv(nf, nf, ks=3, stride=1),
  16. nn.Dropout2d(p),
  17. _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
  18. ]
  19. nf *= 2
  20. layers += [
  21. _conv(nf, nf, ks=3, stride=1),
  22. _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
  23. Flatten(),
  24. ]
  25. return nn.Sequential(*layers)
  26. def colorize_crit_learner(
  27. data: ImageDataBunch,
  28. loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
  29. nf: int = 256,
  30. ) -> Learner:
  31. return Learner(
  32. data,
  33. custom_gan_critic(nf=nf),
  34. metrics=accuracy_thresh_expand,
  35. loss_func=loss_critic,
  36. wd=1e-3,
  37. )