Browse Source

Finally have a very stable, very cool model

Jason Antic 6 years ago
parent
commit
52b802a166
4 changed files with 149 additions and 243 deletions
  1. ColorizeTraining.ipynb  +101 -8
  2. fasterai/filters.py  +3 -2
  3. fasterai/loss.py  +3 -156
  4. fasterai/tensorboard.py  +42 -77

ColorizeTraining.ipynb  +101 -8

@@ -43,12 +43,9 @@
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
     "BWIMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
     "\n",
-    "proj_id = 'colorizeESR45'\n",
+    "proj_id = 'colorize1'\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
-    "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
-    "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
-    "\n",
     "torch.backends.cudnn.benchmark=True"
    ]
   },
@@ -131,7 +128,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss4(), arch=models.resnet34):\n",
+    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
     "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
     "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
    ]
@@ -157,7 +154,7 @@
     "learn_crit = colorize_crit_learner(data=data, nf=256)\n",
     "learn_crit.unfreeze()\n",
     "\n",
-    "gen_loss = FeatureLoss4()\n",
+    "gen_loss = FeatureLoss()\n",
     "learn_gen = colorize_gen_learner_exp(data=data)\n",
     "\n",
     "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
@@ -386,6 +383,15 @@
     "save()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "load()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -571,7 +577,7 @@
    "source": [
     "lr=lr/1.5\n",
     "sz=224\n",
-    "bs=bs//1.5"
+    "bs=int(bs//1.5)"
    ]
   },
   {
@@ -646,7 +652,94 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "load()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 256px"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr=lr/1.75\n",
+    "sz=256\n",
+    "bs=int(bs//1.5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)\n",
+    "learn_gen.freeze_to(-1)\n",
+    "learn.fit(1,lr/10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.freeze_to(-1)\n",
+    "learn.fit(1,lr)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()\n",
+    "learn.fit(1,lr*unfreeze_fctr)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save()"
+   ]
   },
   {
    "cell_type": "code",

fasterai/filters.py  +3 -2

@@ -30,7 +30,7 @@ class BaseFilter(IFilter):
         #a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        #I've tried padding to the square as well (reflect, symmetric, constant, etc.). Not as good!
         targ_sz = (targ, targ)
-        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
+        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)
 
     def _get_model_ready_image(self, orig:PilImage, sz:int)->PilImage:
         result = self._scale_to_square(orig, sz)
@@ -51,7 +51,7 @@ class BaseFilter(IFilter):
 
     def _unsquare(self, image:PilImage, orig:PilImage)->PilImage:
         targ_sz = orig.size
-        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
+        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
         return image
 
 
@@ -64,6 +64,7 @@ class ColorizerFilter(BaseFilter):
     def filter(self, orig_image:PilImage, filtered_image:PilImage, render_factor:int)->PilImage:
         render_sz = render_factor * self.render_base
         model_image = self._model_process(orig=filtered_image, sz=render_sz)
+
         if self.map_to_orig:
             return self._post_process(model_image, orig_image)
         else:

fasterai/loss.py  +3 -156

@@ -5,162 +5,8 @@ from fastai.callbacks  import hook_outputs
 import torchvision.models as models
 
 
-class FeatureLoss(nn.Module):
-    def __init__(self, layer_wgts:[float]=[5.0,15.0,2.0], gram_wgt:float=5e3):
-        super().__init__()
-        self.gram_wgt = gram_wgt
-        self.base_loss = F.l1_loss
-        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
-        requires_grad(self.m_feat, False)
-        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
-        layer_ids = blocks[2:5]
-        self.loss_features = [self.m_feat[i] for i in layer_ids]
-        self.hooks = hook_outputs(self.loss_features, detach=False)
-        self.wgts = layer_wgts
-        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
-              ] + [f'gram_{i}' for i in range(len(layer_ids))]
-
-    def _gram_matrix(self, x:torch.Tensor):
-        n,c,h,w = x.size()
-        x = x.view(n, c, -1)
-        return (x @ x.transpose(1,2))/(c*h*w)
-
-    def make_features(self, x:torch.Tensor, clone=False):
-        self.m_feat(x)
-        return [(o.clone() if clone else o) for o in self.hooks.stored]
-    
-    def forward(self, input:torch.Tensor, target:torch.Tensor):
-        out_feat = self.make_features(target, clone=True)
-        in_feat = self.make_features(input)
-        self.feat_losses = [self.base_loss(f_in, f_out)*w
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-
-        self.feat_losses += [self.base_loss(input,target)]
-
-        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * self.gram_wgt
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-                             
-        self.metrics = dict(zip(self.metric_names, self.feat_losses))
-        return sum(self.feat_losses)
-    
-    def __del__(self): 
-        self.hooks.remove()
-
-
-
-class FeatureLoss2(nn.Module):
-    def __init__(self, layer_wgts:[float]=[20.0,70.0,10.0], gram_wgt:float=5e3):
-        super().__init__()
-        self.gram_wgt = gram_wgt
-        self.base_loss = F.l1_loss
-        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
-        requires_grad(self.m_feat, False)
-        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
-        layer_ids = blocks[2:5]
-        self.loss_features = [self.m_feat[i] for i in layer_ids]
-        self.hooks = hook_outputs(self.loss_features, detach=False)
-        self.wgts = layer_wgts
-        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
-              ] + [f'gram_{i}' for i in range(len(layer_ids))]
-
-    def _gram_matrix(self, x:torch.Tensor):
-        n,c,h,w = x.size()
-        x = x.view(n, c, -1)
-        return (x @ x.transpose(1,2))/(c*h*w)
-
-    def make_features(self, x:torch.Tensor, clone=False):
-        self.m_feat(x)
-        return [(o.clone() if clone else o) for o in self.hooks.stored]
-    
-    def forward(self, input:torch.Tensor, target:torch.Tensor):
-        out_feat = self.make_features(target, clone=True)
-        in_feat = self.make_features(input)
-        self.feat_losses = [self.base_loss(f_in, f_out)*w
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-
-        self.feat_losses += [self.base_loss(input,target)*100]
-
-        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * self.gram_wgt
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-                             
-        self.metrics = dict(zip(self.metric_names, self.feat_losses))
-        return sum(self.feat_losses)
-    
-    def __del__(self): 
-        self.hooks.remove()
-
-
-#Includes wasserstein loss
-class FeatureLoss3(nn.Module):
-    def __init__(self, layer_wgts=[5,15,2], wass_wgts=[3.0,0.7,0.01]):
-        super().__init__()
-
-        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
-        requires_grad(self.m_feat, False)
-        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
-        layer_ids = blocks[2:5]
-        self.loss_features = [self.m_feat[i] for i in layer_ids]
-        self.hooks = hook_outputs(self.loss_features, detach=False)
-        self.wgts = layer_wgts
-        self.wass_wgts = wass_wgts
-        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
-              ] + [f'wass_{i}' for i in range(len(layer_ids))]
-        self.base_loss = F.l1_loss
-
-    def _make_features(self, x, clone=False):
-        self.m_feat(x)
-        return [(o.clone() if clone else o) for o in self.hooks.stored]
-
-    def _calc_2_moments(self, tensor):
-        chans = tensor.shape[1]
-        tensor = tensor.view(1, chans, -1)
-        n = tensor.shape[2] 
-        mu = tensor.mean(2)
-        tensor = (tensor - mu[:,:,None]).squeeze(0)
-        cov = torch.mm(tensor, tensor.t()) / float(n)   
-        return mu, cov
-
-    def _get_style_vals(self, tensor):
-        mean, cov = self._calc_2_moments(tensor) 
-        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
-        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))     
-        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())  
-        tr_cov = eigvals.clamp(min=0).sum() 
-        return mean, tr_cov, root_cov
-
-    def _calc_l2wass_dist(self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth):
-        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
-        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
-        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
-        var_overlap = torch.sqrt(torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0)+1e-8).sum()
-        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
-        return dist
-
-    def _single_wass_loss(self, pred, targ):
-        mean_test, tr_cov_test, root_cov_test = targ
-        mean_synth, cov_synth = self._calc_2_moments(pred)
-        loss = self._calc_l2wass_dist(mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth)
-        return loss
-    
-    def forward(self, input, target):
-        out_feat = self._make_features(target, clone=True)
-        in_feat = self._make_features(input)
-        self.feat_losses = [self.base_loss(input,target)]
-        self.feat_losses += [self.base_loss(f_in, f_out)*w
-                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
-        
-        styles = [self._get_style_vals(i) for i in out_feat]
-        self.feat_losses += [self._single_wass_loss(f_pred, f_targ)*w
-                            for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)]
-        
-        self.metrics = dict(zip(self.metric_names, self.feat_losses))
-        return sum(self.feat_losses)
-    
-    def __del__(self): self.hooks.remove()
-
-
 #"Before activations" in ESRGAN paper
-class FeatureLoss4(nn.Module):
+class FeatureLoss(nn.Module):
     def __init__(self, layer_wgts=[5,15,2]):
         super().__init__()
 
@@ -188,4 +34,5 @@ class FeatureLoss4(nn.Module):
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
     
-    def __del__(self): self.hooks.remove()
+    def __del__(self): self.hooks.remove()
+
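
The body of the kept FeatureLoss is elided from this hunk; only its signature and tail survive the diff context. Reconstructed from the removed variants above, it amounts to pixel L1 plus weighted L1 over hooked VGG16-bn features, without the gram or Wasserstein terms. This is a sketch, not the file's exact code, and the hook placement implied by the "before activations" comment is an assumption:

    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision.models as models
    from fastai.torch_core import requires_grad, children
    from fastai.callbacks import hook_outputs

    class FeatureLossSketch(nn.Module):
        "Sketch: pixel L1 plus weighted L1 over hooked VGG16-bn feature maps."
        def __init__(self, layer_wgts=[5, 15, 2]):
            super().__init__()
            self.base_loss = F.l1_loss
            self.m_feat = models.vgg16_bn(True).features.cuda().eval()
            requires_grad(self.m_feat, False)
            # Hook the layer just ahead of each MaxPool, as the removed variants did;
            # moving the index one step earlier would hook pre-activation features.
            blocks = [i - 1 for i, o in enumerate(children(self.m_feat))
                      if isinstance(o, nn.MaxPool2d)]
            self.hooks = hook_outputs([self.m_feat[i] for i in blocks[2:5]], detach=False)
            self.wgts = layer_wgts

        def _make_features(self, x, clone=False):
            self.m_feat(x)
            return [(o.clone() if clone else o) for o in self.hooks.stored]

        def forward(self, input, target):
            out_feat = self._make_features(target, clone=True)
            in_feat = self._make_features(input)
            losses = [self.base_loss(input, target)]
            losses += [self.base_loss(f_in, f_out) * w
                       for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
            return sum(losses)

        def __del__(self): self.hooks.remove()
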

fasterai/tensorboard.py  +42 -77

@@ -10,6 +10,7 @@ import torchvision.utils as vutils
 from tensorboardX import SummaryWriter
 
 
+
 class ModelGraphVisualizer():
     def __init__(self):
         return 
@@ -26,10 +27,10 @@ class ModelHistogramVisualizer():
     def __init__(self):
         return 
 
-    def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
+    def write_tensorboard_histograms(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
         try:
             for param_name, param in model.named_parameters():
-                tbwriter.add_histogram(name + '/weights/' + param_name, param, iter_count)
+                tbwriter.add_histogram(name + '/weights/' + param_name, param, iteration)
         except Exception as e:
             print(("Failed to update histogram for model:  {0}").format(e))
 
@@ -38,7 +39,7 @@ class ModelStatsVisualizer():
     def __init__(self):
         return 
 
-    def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
+    def write_tensorboard_stats(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
         try:
             gradients = [x.grad  for x in model.parameters() if x.grad is not None]
             gradient_nps = [to_np(x.data) for x in gradients]
@@ -47,45 +48,45 @@ class ModelStatsVisualizer():
                 return 
 
             avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
-            tbwriter.add_scalar(name + '/gradients/avg_norm', avg_norm, iter_count)
+            tbwriter.add_scalar(name + '/gradients/avg_norm', avg_norm, iteration)
 
             median_norm = statistics.median(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(name + '/gradients/median_norm', median_norm, iter_count)
+            tbwriter.add_scalar(name + '/gradients/median_norm', median_norm, iteration)
 
             max_norm = max(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(name + '/gradients/max_norm', max_norm, iter_count)
+            tbwriter.add_scalar(name + '/gradients/max_norm', max_norm, iteration)
 
             min_norm = min(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(name + '/gradients/min_norm', min_norm, iter_count)
+            tbwriter.add_scalar(name + '/gradients/min_norm', min_norm, iteration)
 
             num_zeros = sum((np.asarray(x)==0.0).sum() for x in  gradient_nps)
-            tbwriter.add_scalar(name + '/gradients/num_zeros', num_zeros, iter_count)
+            tbwriter.add_scalar(name + '/gradients/num_zeros', num_zeros, iteration)
 
 
             avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
-            tbwriter.add_scalar(name + '/gradients/avg_gradient', avg_gradient, iter_count)
+            tbwriter.add_scalar(name + '/gradients/avg_gradient', avg_gradient, iteration)
 
             median_gradient = statistics.median(x.data.median() for x in gradients)
-            tbwriter.add_scalar(name + '/gradients/median_gradient', median_gradient, iter_count)
+            tbwriter.add_scalar(name + '/gradients/median_gradient', median_gradient, iteration)
 
             max_gradient = max(x.data.max() for x in gradients) 
-            tbwriter.add_scalar(name + '/gradients/max_gradient', max_gradient, iter_count)
+            tbwriter.add_scalar(name + '/gradients/max_gradient', max_gradient, iteration)
 
             min_gradient = min(x.data.min() for x in gradients) 
-            tbwriter.add_scalar(name + '/gradients/min_gradient', min_gradient, iter_count)
+            tbwriter.add_scalar(name + '/gradients/min_gradient', min_gradient, iteration)
         except Exception as e:
             print(("Failed to update tensorboard stats for model:  {0}").format(e))
 
 class ImageGenVisualizer():
-    def output_image_gen_visuals(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iter_count:int, tbwriter:SummaryWriter):
-        self._output_visuals(learn=learn, batch=val_batch, iter_count=iter_count, tbwriter=tbwriter, ds_type=DatasetType.Valid)
-        self._output_visuals(learn=learn, batch=trn_batch, iter_count=iter_count, tbwriter=tbwriter, ds_type=DatasetType.Train)
+    def output_image_gen_visuals(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
+        self._output_visuals(learn=learn, batch=val_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Valid)
+        self._output_visuals(learn=learn, batch=trn_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Train)
 
-    def _output_visuals(self, learn:Learner, batch:Tuple, iter_count:int, tbwriter:SummaryWriter, ds_type: DatasetType):
+    def _output_visuals(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type: DatasetType):
         image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
-        self._write_tensorboard_images(image_sets=image_sets, iter_count=iter_count, tbwriter=tbwriter, ds_type=ds_type)
+        self._write_tensorboard_images(image_sets=image_sets, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
     
-    def _write_tensorboard_images(self, image_sets:[ModelImageSet], iter_count:int, tbwriter:SummaryWriter, ds_type: DatasetType):
+    def _write_tensorboard_images(self, image_sets:[ModelImageSet], iteration:int, tbwriter:SummaryWriter, ds_type: DatasetType):
         try:
             orig_images = []
             gen_images = []
@@ -98,17 +99,15 @@ class ImageGenVisualizer():
 
             prefix = str(ds_type)
 
-            tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iter_count)
-            tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iter_count)
-            tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iter_count)
+            tbwriter.add_image(prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iteration)
+            tbwriter.add_image(prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iteration)
+            tbwriter.add_image(prefix + ' real images', vutils.make_grid(real_images, normalize=True), iteration)
         except Exception as e:
             print(("Failed to update tensorboard images for model:  {0}").format(e))
 
 
#--------Below are what you actually want to use, in practice----------------#
 
-
-
 class LearnerTensorboardWriter(LearnerCallback):
     def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000, stats_iters:int=1000):
         super().__init__(learn=learn)
@@ -122,6 +121,7 @@ class LearnerTensorboardWriter(LearnerCallback):
         self.weight_vis = ModelHistogramVisualizer()
         self.model_vis = ModelStatsVisualizer() 
         self.data = None
+        self.metrics_root = '/metrics/'
 
     def _update_batches_if_needed(self):
         #one_batch function is extremely slow.  this is an optimization
@@ -133,35 +133,26 @@ class LearnerTensorboardWriter(LearnerCallback):
             self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=True, denorm=False, cpu=False)
 
     def _write_model_stats(self, iteration):
-        self.model_vis.write_tensorboard_stats(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter) 
+        self.model_vis.write_tensorboard_stats(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter) 
 
     def _write_training_loss(self, iteration, last_loss):
         trn_loss = to_np(last_loss)
-        self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
+        self.tbwriter.add_scalar(self.metrics_root + 'train_loss', trn_loss, iteration)
 
     def _write_weight_histograms(self, iteration):
-        self.weight_vis.write_tensorboard_histograms(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter)
+        self.weight_vis.write_tensorboard_histograms(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
 
-    def _write_val_loss(self, iteration, last_metrics):
-        #TODO: Not a fan of this indexing but...what to do?
-        val_loss = last_metrics[0]
-        self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration)  
-    
-    def _write_metrics(self, iteration):
-        rec = self.learn.recorder
 
-        for i, name in enumerate(rec.names[3:]):
-            if len(rec.metrics) == 0: continue
-            if len(rec.metrics[-1:]) == 0: continue
-            if len(rec.metrics[-1:][0]) == 0: continue
-            value = rec.metrics[-1:][0][i]
-            if value is None: continue
-            self.tbwriter.add_scalar('/metrics/' + name, to_np(value), iteration) 
+    def _write_metrics(self, iteration, last_metrics, start_idx:int=2):
+        recorder = self.learn.recorder
 
+        for i, name in enumerate(recorder.names[start_idx:]):
+            if len(last_metrics) < i+1: return 
+            value = last_metrics[i]
+            self.tbwriter.add_scalar(self.metrics_root + name, value, iteration)  
+  
     def on_batch_end(self, last_loss, metrics, iteration, **kwargs):
-        if iteration==0:
-            return
-
+        if iteration==0: return
         self._update_batches_if_needed()
 
         if iteration % self.loss_iters == 0: 
@@ -174,8 +165,7 @@ class LearnerTensorboardWriter(LearnerCallback):
             self._write_model_stats(iteration)
 
     def on_epoch_end(self, metrics, last_metrics, iteration, **kwargs):
-        self._write_val_loss(iteration, last_metrics)
-        self._write_metrics(iteration)
+        self._write_metrics(iteration, last_metrics)
 
 
 class GANTensorboardWriter(LearnerTensorboardWriter):
@@ -186,59 +176,34 @@ class GANTensorboardWriter(LearnerTensorboardWriter):
         self.visual_iters = visual_iters
         self.img_gen_vis = ImageGenVisualizer()
 
-    #override
-    def _write_training_loss(self, iteration, last_loss):
-        trainer = self.learn.gan_trainer
-        recorder = trainer.recorder
-
-        if len(recorder.losses) > 0:      
-            trn_loss = to_np((recorder.losses[-1:])[0])
-            self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
-
     #override
     def _write_weight_histograms(self, iteration):
         trainer = self.learn.gan_trainer
         generator = trainer.generator
         critic = trainer.critic
-
-        self.weight_vis.write_tensorboard_histograms(model=generator, iter_count=iteration, tbwriter=self.tbwriter, name='generator')
-        self.weight_vis.write_tensorboard_histograms(model=critic, iter_count=iteration, tbwriter=self.tbwriter, name='critic')
+        self.weight_vis.write_tensorboard_histograms(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
+        self.weight_vis.write_tensorboard_histograms(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
 
     #override
     def _write_model_stats(self, iteration):
         trainer = self.learn.gan_trainer
         generator = trainer.generator
         critic = trainer.critic
-
-        self.model_vis.write_tensorboard_stats(model=generator, iter_count=iteration, tbwriter=self.tbwriter, name='generator')
-        self.model_vis.write_tensorboard_stats(model=critic, iter_count=iteration, tbwriter=self.tbwriter, name='critic')
-
-    #override
-    def _write_val_loss(self, iteration, last_metrics):
-        trainer = self.learn.gan_trainer
-        recorder = trainer.recorder 
-
-        if len(recorder.val_losses) > 0:
-            val_loss = (recorder.val_losses[-1:])[0]
-            self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration) 
-
+        self.model_vis.write_tensorboard_stats(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
+        self.model_vis.write_tensorboard_stats(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
 
     def _write_images(self, iteration):
         trainer = self.learn.gan_trainer
         recorder = trainer.recorder
-
         gen_mode = trainer.gen_mode
         trainer.switch(gen_mode=True)
         self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
-                                                iter_count=iteration, tbwriter=self.tbwriter)
+                                               iteration=iteration, tbwriter=self.tbwriter)
         trainer.switch(gen_mode=gen_mode)
 
     def on_batch_end(self, metrics, iteration, **kwargs):
         super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
-
-        if iteration==0:
-            return
-
+        if iteration==0: return
         if iteration % self.visual_iters == 0:
             self._write_images(iteration)
 
@@ -254,7 +219,7 @@ class ImageGenTensorboardWriter(LearnerTensorboardWriter):
 
     def _write_images(self, iteration):
         self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
-            iter_count=iteration, tbwriter=self.tbwriter)
+            iteration=iteration, tbwriter=self.tbwriter)
 
     def on_batch_end(self, metrics, iteration, **kwargs):
         super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
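
Usage note: these writers are fastai LearnerCallbacks. A minimal wiring sketch, assuming GANTensorboardWriter forwards learn/base_dir/name to the base class and takes a visual_iters argument as the hunk above suggests; the name and interval values here are placeholders:

    from pathlib import Path
    from fastai.basic_train import Learner
    from fasterai.tensorboard import GANTensorboardWriter

    def attach_tensorboard(learn: Learner, proj_id: str = 'colorize1') -> Learner:
        # Scalars now land under '/metrics/<name>' (the new metrics_root);
        # image grids are written every visual_iters batches.
        base_dir = Path('data/tensorboard') / proj_id
        learn.callbacks.append(GANTensorboardWriter(
            learn=learn, base_dir=base_dir, name='GanLearner', visual_iters=100))
        return learn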