1234567891011121314151617181920212223242526272829 |
- from fastai.basic_train import Learner, LearnerCallback
- from fastai.vision.gan import GANLearner
class GANSaveCallback(LearnerCallback):
    """Periodically save the generator `Learner` while a `GANLearner` trains.

    At the end of every training batch whose `iteration` is a positive
    multiple of `save_iters`, `learn_gen.save` is called with the name
    ``'{filename}_{epoch}_{iteration}'``.

    Args:
        learn: The `GANLearner` whose training loop drives this callback.
        learn_gen: The generator `Learner` whose state is saved.
        filename: Prefix for the saved checkpoint names.
        save_iters: Save interval, in training iterations. Must be >= 1.

    Raises:
        ValueError: If `save_iters` is less than 1.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        # Validate first so a bad config fails before any base-class setup
        # runs; save_iters=0 would otherwise only surface later as a
        # ZeroDivisionError inside `on_batch_end`, mid-training.
        if save_iters < 1:
            raise ValueError(
                'save_iters must be a positive integer, got {}'.format(save_iters)
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save the generator when `iteration` is a positive multiple of `save_iters`.

        Args:
            iteration: Training-iteration counter supplied by the callback handler.
            epoch: Current epoch supplied by the callback handler.
            **kwargs: Remaining callback state; only `train` is consulted.
        """
        # NOTE(review): fastai v1 also fires `on_batch_end` for validation
        # batches without advancing `iteration`, so the same checkpoint could
        # be rewritten once per validation batch. Skip non-training batches;
        # default to True when `train` is absent to keep the original behavior.
        if not kwargs.get('train', True):
            return
        # Iteration 0 is skipped explicitly (0 % n == 0 would otherwise save).
        if iteration == 0:
            return
        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int) -> None:
        """Save `learn_gen` under ``'{filename}_{epoch}_{iteration}'``."""
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)
|