save.py 939 B

123456789101112131415161718192021
  1. from fastai.torch_core import *
  2. from fastai.basic_data import DataBunch
  3. from fastai.callback import *
  4. from fastai.basic_train import Learner, LearnerCallback
  5. from fastai.vision.gan import GANLearner
  6. class GANSaveCallback(LearnerCallback):
  7. "A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."
  8. def __init__(self, learn:GANLearner, learn_gen:Learner, filename:str, save_iters:int=1000):
  9. super().__init__(learn)
  10. self.learn_gen, self.filename, self.save_iters = learn_gen, filename, save_iters
  11. def on_batch_end(self, iteration:int, epoch:int, **kwargs)->None:
  12. if iteration == 0: return
  13. if iteration % self.save_iters == 0:
  14. self._save_gen_learner(iteration=iteration, epoch=epoch)
  15. def _save_gen_learner(self, iteration:int, epoch:int):
  16. fn = self.filename + '_' + str(epoch) + '_' + str(iteration)
  17. self.learn_gen.save(fn)