loss.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from fastai.torch_imports import *
  2. from fastai.core import *
  3. from fastai.conv_learner import children
  4. from .modules import SaveFeatures
  5. import torchvision.models as models
  class FeatureLoss(nn.Module):
      """Perceptual ("feature") loss computed with a frozen, truncated VGG16.

      The loss combines
        * a pixel-space L1 term between ``input`` and ``target``, scaled by 1/100, and
        * one weighted L1 term per hooked VGG layer, comparing the activations
          produced by ``input`` with those produced by ``target``.

      Activations are captured at the layer immediately preceding each of the
      first three ``nn.MaxPool2d`` layers of the VGG network.
      Call :meth:`close` when finished so the feature hooks are removed.
      """

      def __init__(self, block_wgts:[float]=(0.2,0.7,0.1), multiplier:float=1.0):
          """Build the frozen VGG feature extractor and attach feature hooks.

          Args:
              block_wgts: weight of each hooked VGG layer's L1 term, in network
                  order. Three layers are hooked; ``forward`` pairs weights and
                  layers with ``zip``, so surplus weights are ignored and a
                  shorter sequence drops the deeper layers from the loss.
              multiplier: scale applied to the final loss (or to every term
                  when ``forward`` is called with ``sum_layers=False``).
          """
          super().__init__()
          # NOTE(review): vgg16(True) presumably returns the pretrained VGG16
          # feature extractor as an indexable nn.Sequential -- confirm in fastai.
          m_vgg = vgg16(True)
          # Index of the layer just before every max-pool, i.e. the last layer
          # of each convolutional block.
          blocks = [i-1 for i,o in enumerate(children(m_vgg)) if isinstance(o,nn.MaxPool2d)]
          layer_ids = blocks[:3]
          # Keep only the layers needed to reach the deepest hooked layer;
          # anything after it would be computed and never read.
          vgg_layers = children(m_vgg)[:layer_ids[-1]+1]
          # Frozen, eval-mode copy on the GPU: it is only used to measure the
          # loss, never trained.
          m_vgg = nn.Sequential(*vgg_layers).cuda().eval()
          set_trainable(m_vgg, False)
          self.m = m_vgg
          # Copy so the weights are never shared with the caller or other instances.
          self.wgts = list(block_wgts)
          # SaveFeatures presumably registers a forward hook that stores the
          # module output in `.features` (that attribute is read in forward).
          self.sfs = [SaveFeatures(m_vgg[i]) for i in layer_ids]
          self.multiplier = multiplier

      def forward(self, input, target, sum_layers:bool=True):
          """Compute the feature loss of ``input`` against ``target``.

          Args:
              input: generated image batch.
              target: reference image batch.
              sum_layers: if true, return a single summed loss; otherwise
                  return the individual terms.

          Returns:
              If ``sum_layers`` is true, (pixel term + weighted feature terms)
              * ``multiplier``. Otherwise the list
              ``[pixel_term, feature_term_0, feature_term_1, ...]`` with every
              entry multiplied by ``multiplier``.
          """
          # Run the target first (wrapped with VV, presumably a no-grad
          # Variable) and snapshot its hooked activations, because the next
          # pass through self.m overwrites the hooks' stored features.
          self.m(VV(target.data))
          res = [F.l1_loss(input,target)/100]
          targ_feat = [V(o.features.data.clone()) for o in self.sfs]
          self.m(input)
          res += [F.l1_loss(self._flatten(inp.features),self._flatten(targ))*wgt
                  for inp,targ,wgt in zip(self.sfs, targ_feat, self.wgts)]
          if sum_layers:
              return sum(res)*self.multiplier
          # Scale each term individually: multiplying the Python list itself
          # raises TypeError for a float multiplier (and repeats the list for
          # an int one).
          return [r*self.multiplier for r in res]

      def _flatten(self, x:torch.Tensor):
          """Reshape ``x`` to (x.size(0), -1), keeping the leading dimension."""
          return x.view(x.size(0), -1)

      def close(self):
          """Call ``remove()`` on every SaveFeatures hook object created in __init__."""
          for o in self.sfs:
              o.remove()