save.py 936 B

1234567891011121314151617181920212223242526272829
  1. from fastai.basic_train import Learner, LearnerCallback
  2. from fastai.vision.gan import GANLearner
  3. class GANSaveCallback(LearnerCallback):
  4. """A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
  5. def __init__(
  6. self,
  7. learn: GANLearner,
  8. learn_gen: Learner,
  9. filename: str,
  10. save_iters: int = 1000,
  11. ):
  12. super().__init__(learn)
  13. self.learn_gen = learn_gen
  14. self.filename = filename
  15. self.save_iters = save_iters
  16. def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
  17. if iteration == 0:
  18. return
  19. if iteration % self.save_iters == 0:
  20. self._save_gen_learner(iteration=iteration, epoch=epoch)
  21. def _save_gen_learner(self, iteration: int, epoch: int):
  22. filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
  23. self.learn_gen.save(filename)