Bläddra i källkod

Cleanup of Tensorboard functionality

Jason Antic 6 år sedan
förälder
incheckning
bae06dddf7
3 ändrade filer med 195 tillägg och 197 borttagningar
  1. 56 30
      ColorizeVisualization.ipynb
  2. 0 31
      fasterai/images.py
  3. 139 136
      fasterai/tensorboard.py

+ 56 - 30
ColorizeVisualization.ipynb

@@ -1,5 +1,15 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='3' "
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -30,7 +40,6 @@
     "from fasterai.generators import *\n",
     "from pathlib import Path\n",
     "from itertools import repeat\n",
-    "torch.cuda.set_device(2)\n",
     "plt.style.use('dark_background')\n",
     "torch.backends.cudnn.benchmark=True"
    ]
@@ -47,7 +56,18 @@
     "#11GB can take a factor of 42 max.  Performance generally gracefully degrades with lower factors, \n",
     "#though you may also find that certain images will actually render better at lower numbers.  \n",
     "#This tends to be the case with the oldest photos.\n",
-    "render_factor=62"
+    "render_factor=16"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
+    "    return unet_learner3(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
+    "                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=1.5)"
    ]
   },
   {
@@ -56,10 +76,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "data = get_colorize_data(sz=256, bs=4, crappy_path=path, good_path=path, keep_pct=0.01)\n",
-    "learn = colorize_gen_learner(data=data)\n",
+    "data = get_colorize_data(sz=128, bs=32, crappy_path=path, good_path=path, keep_pct=0.01)\n",
+    "learn = colorize_gen_learner_exp(data=data)\n",
     "learn.path = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
-    "learn.load('gen-pre-a')\n",
+    "learn.load('colorize3_gen_96')\n",
+    "learn.model.eval()\n",
     "filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)\n",
     "vis = ModelImageVisualizer(filtr, results_dir='result_images')"
    ]
@@ -376,7 +397,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/teddy_rubble.jpg\")"
    ]
   },
   {
@@ -457,7 +478,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/AnselAdamsYosemite.jpg\", render_factor=35)"
+    "vis.plot_transformed_image(\"test_images/AnselAdamsYosemite.jpg\")"
    ]
   },
   {
@@ -466,7 +487,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/unnamed.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/unnamed.jpg\")"
    ]
   },
   {
@@ -601,7 +622,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/women-bikers.png\", figsize=(60,60), render_factor=42)"
+    "vis.plot_transformed_image(\"test_images/women-bikers.png\")"
    ]
   },
   {
@@ -664,7 +685,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/poverty.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/poverty.jpg\")"
    ]
   },
   {
@@ -709,7 +730,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\", render_factor=31)"
+    "vis.plot_transformed_image(\"test_images/FatMenClub.jpg\")"
    ]
   },
   {
@@ -718,7 +739,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/EgyptColosus.jpg\", render_factor=31)"
+    "vis.plot_transformed_image(\"test_images/EgyptColosus.jpg\")"
    ]
   },
   {
@@ -781,7 +802,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/covered-wagons-traveling.jpg\", render_factor=19)"
+    "vis.plot_transformed_image(\"test_images/covered-wagons-traveling.jpg\")"
    ]
   },
   {
@@ -1042,7 +1063,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1910Racket.png\", render_factor=31)"
+    "vis.plot_transformed_image(\"test_images/1910Racket.png\")"
    ]
   },
   {
@@ -1096,7 +1117,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1860Girls.jpg\", render_factor=41)"
+    "vis.plot_transformed_image(\"test_images/1860Girls.jpg\")"
    ]
   },
   {
@@ -1186,7 +1207,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\", render_factor=17)"
+    "vis.plot_transformed_image(\"test_images/ViennaBoys1880s.png\")"
    ]
   },
   {
@@ -1213,7 +1234,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/ArkansasCowboys1880s.jpg\", render_factor=24)"
+    "vis.plot_transformed_image(\"test_images/ArkansasCowboys1880s.jpg\")"
    ]
   },
   {
@@ -1222,7 +1243,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\", render_factor=34)"
+    "vis.plot_transformed_image(\"test_images/Ballet1890Russia.jpg\")"
    ]
   },
   {
@@ -1258,7 +1279,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\", render_factor=41)"
+    "vis.plot_transformed_image(\"test_images/Harlem1932.jpg\")"
    ]
   },
   {
@@ -1321,7 +1342,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/1890sTouristsEgypt.png\")"
    ]
   },
   {
@@ -1357,7 +1378,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\", render_factor=38)"
+    "vis.plot_transformed_image(\"test_images/Texas1938Woman.png\")"
    ]
   },
   {
@@ -1582,7 +1603,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\", render_factor=38)"
+    "vis.plot_transformed_image(\"test_images/Apsaroke1908.png\")"
    ]
   },
   {
@@ -1672,7 +1693,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/MadisonSquare1900.jpg\")"
    ]
   },
   {
@@ -1762,7 +1783,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/SenecaNative1908.jpg\", render_factor=19)"
+    "vis.plot_transformed_image(\"test_images/SenecaNative1908.jpg\")"
    ]
   },
   {
@@ -1780,7 +1801,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\", render_factor=33)"
+    "vis.plot_transformed_image(\"test_images/1892WaterLillies.jpg\")"
    ]
   },
   {
@@ -2050,7 +2071,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\", render_factor=40)"
+    "vis.plot_transformed_image(\"test_images/Donegal1907Yarn.jpg\")"
    ]
   },
   {
@@ -2518,7 +2539,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\", render_factor=42)"
+    "vis.plot_transformed_image(\"test_images/TheatreGroupBombay1875.jpg\")"
    ]
   },
   {
@@ -2680,7 +2701,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\", render_factor=43)"
+    "vis.plot_transformed_image(\"test_images/SecondHandClothesLondonLate1800s.jpg\")"
    ]
   },
   {
@@ -2740,14 +2761,19 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "vis.plot_transformed_image(\"test_images//ParisLate1800s.jpg\", render_factor=40)"
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "for factor in range(20,66):\n",
+    "    vis.plot_transformed_image(\"test_images/1890sMedStudents.png\", render_factor=factor)"
+   ]
   },
   {
    "cell_type": "markdown",

+ 0 - 31
fasterai/images.py

@@ -1,31 +0,0 @@
-import numpy as np
-from fastai.core import *
-from fastai.vision import *
-from pathlib import Path
-from itertools import repeat
-from PIL import Image as PilImage
-from numpy import ndarray
-from datetime import datetime
-from fastai.vision.image import *
-
-
-class ModelImageSet():
-    @staticmethod
-    def get_list_from_model(learn: Learner, ds_type: DatasetType, batch:Tuple)->[]:
-        image_sets = []
-        x,y = batch[0],batch[1]
-        #x,y = learn.data.one_batch(ds_type, detach=False, denorm=False)
-        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
-        
-        for orig,real,gen_image in zip(x,y,preds):
-            orig_image = Image(orig)
-            real_image = Image(real)
-            image_set = ModelImageSet(orig_image, real_image, gen_image)
-            image_sets.append(image_set)
-
-        return image_sets  
-
-    def __init__(self, orig:Image, real:Image, gen:Image):
-        self.orig=orig
-        self.real=real
-        self.gen=gen

+ 139 - 136
fasterai/tensorboard.py

@@ -5,132 +5,133 @@ from fastai.callbacks import *
 from fastai.vision.gan import *
 from fastai.core import *
 import statistics
-from .images import ModelImageSet
 import torchvision.utils as vutils
 from tensorboardX import SummaryWriter
 
-
+class ModelImageSet():
+    @staticmethod
+    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
+        image_sets = []
+        x,y = batch[0],batch[1]
+        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
+        
+        for orig_px, real_px, gen in zip(x,y,preds):
+            orig = Image(px=orig_px)
+            real = Image(px=real_px)
+            image_set = ModelImageSet(orig=orig, real=real, gen=gen)
+            image_sets.append(image_set)
+
+        return image_sets  
+
+    def __init__(self, orig:Image, real:Image, gen:Image):
+        self.orig = orig
+        self.real = real
+        self.gen = gen
+
+#TODO:  There aren't any callbacks using this yet.  Not sure if we want this included (not sure if it's useful, honestly)
 class ModelGraphVisualizer():
     def __init__(self):
         return
 
-    def write_model_graph_to_tensorboard(self, md: DataBunch, model: nn.Module, tbwriter: SummaryWriter):
-        try:
-            x, y = md.one_batch(DatasetType.Valid, detach=False, denorm=False)
-            tbwriter.add_graph(model, x)
-        except Exception as e:
-            print(("Failed to generate graph for model: {0}. Note that there's an outstanding issue with "
-                   + "scopes being addressed here:  https://github.com/pytorch/pytorch/pull/12400").format(e))
+    def write_model_graph_to_tensorboard(self, md:DataBunch, model:nn.Module, tbwriter:SummaryWriter):
+        x,y = md.one_batch(ds_type=DatasetType.Valid, detach=False, denorm=False)
+        tbwriter.add_graph(model=model, input_to_model=x)
 
 
 class ModelHistogramVisualizer():
     def __init__(self):
         return
 
-    def write_tensorboard_histograms(self, model: nn.Module, iteration: int, tbwriter: SummaryWriter, name: str = 'model'):
-        try:
-            for param_name, param in model.named_parameters():
-                tbwriter.add_histogram(
-                    name + '/weights/' + param_name, param, iteration)
-        except Exception as e:
-            print(("Failed to update histogram for model:  {0}").format(e))
+    def write_tensorboard_histograms(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
+        for param_name, values in model.named_parameters():
+            tag = name + '/weights/' + param_name
+            tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
 
 
 class ModelStatsVisualizer():
     def __init__(self):
-        return
+        self.gradients_root = '/gradients/'
 
-    def write_tensorboard_stats(self, model: nn.Module, iteration: int, tbwriter: SummaryWriter, name: str = 'model_stats'):
-        try:
-            gradients = [x.grad for x in model.parameters()
-                         if x.grad is not None]
-            gradient_nps = [to_np(x.data) for x in gradients]
+    def write_tensorboard_stats(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
+        gradients = [x.grad for x in model.parameters() if x.grad is not None]
+        gradient_nps = [to_np(x.data) for x in gradients]
 
-            if len(gradients) == 0:
-                return
+        if len(gradients) == 0: return
 
-            avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/avg_norm', avg_norm, iteration)
+        avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'avg_norm', scalar_value=avg_norm, global_step=iteration)
 
-            median_norm = statistics.median(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/median_norm', median_norm, iteration)
+        median_norm = statistics.median(x.data.norm() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'median_norm', scalar_value=median_norm, global_step=iteration)
 
-            max_norm = max(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/max_norm', max_norm, iteration)
+        max_norm = max(x.data.norm() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'max_norm', scalar_value=max_norm, global_step=iteration)
 
-            min_norm = min(x.data.norm() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/min_norm', min_norm, iteration)
+        min_norm = min(x.data.norm() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'min_norm', scalar_value=min_norm, global_step=iteration)
 
-            num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
-            tbwriter.add_scalar(
-                name + '/gradients/num_zeros', num_zeros, iteration)
+        num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'num_zeros', scalar_value=num_zeros, global_step=iteration)
 
-            avg_gradient = sum(x.data.mean() for x in gradients)/len(gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/avg_gradient', avg_gradient, iteration)
+        avg_gradient = sum(x.data.mean() for x in gradients)/len(gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'avg_gradient', scalar_value=avg_gradient, global_step=iteration)
 
-            median_gradient = statistics.median(
-                x.data.median() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/median_gradient', median_gradient, iteration)
+        median_gradient = statistics.median(x.data.median() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'median_gradient', scalar_value=median_gradient, global_step=iteration)
 
-            max_gradient = max(x.data.max() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/max_gradient', max_gradient, iteration)
+        max_gradient = max(x.data.max() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'max_gradient', scalar_value=max_gradient, global_step=iteration)
 
-            min_gradient = min(x.data.min() for x in gradients)
-            tbwriter.add_scalar(
-                name + '/gradients/min_gradient', min_gradient, iteration)
-        except Exception as e:
-            print(
-                ("Failed to update tensorboard stats for model:  {0}").format(e))
+        min_gradient = min(x.data.min() for x in gradients)
+        tbwriter.add_scalar(
+            tag=name + self.gradients_root + 'min_gradient', scalar_value=min_gradient, global_step=iteration)
 
 
 class ImageGenVisualizer():
-    def output_image_gen_visuals(self, learn: Learner, trn_batch: Tuple, val_batch: Tuple, iteration: int, tbwriter: SummaryWriter):
+    def output_image_gen_visuals(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
         self._output_visuals(learn=learn, batch=val_batch, iteration=iteration,
                              tbwriter=tbwriter, ds_type=DatasetType.Valid)
         self._output_visuals(learn=learn, batch=trn_batch, iteration=iteration,
                              tbwriter=tbwriter, ds_type=DatasetType.Train)
 
-    def _output_visuals(self, learn: Learner, batch: Tuple, iteration: int, tbwriter: SummaryWriter, ds_type: DatasetType):
+    def _output_visuals(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
         image_sets = ModelImageSet.get_list_from_model(
             learn=learn, batch=batch, ds_type=ds_type)
         self._write_tensorboard_images(
             image_sets=image_sets, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
 
-    def _write_tensorboard_images(self, image_sets: [ModelImageSet], iteration: int, tbwriter: SummaryWriter, ds_type: DatasetType):
-        try:
-            orig_images = []
-            gen_images = []
-            real_images = []
+    def _write_tensorboard_images(self, image_sets:[ModelImageSet], iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
+        orig_images = []
+        gen_images = []
+        real_images = []
 
-            for image_set in image_sets:
-                orig_images.append(image_set.orig.px)
-                gen_images.append(image_set.gen.px)
-                real_images.append(image_set.real.px)
+        for image_set in image_sets:
+            orig_images.append(image_set.orig.px)
+            gen_images.append(image_set.gen.px)
+            real_images.append(image_set.real.px)
 
-            prefix = str(ds_type)
+        prefix = ds_type.name
 
-            tbwriter.add_image(
-                prefix + ' orig images', vutils.make_grid(orig_images, normalize=True), iteration)
-            tbwriter.add_image(
-                prefix + ' gen images', vutils.make_grid(gen_images, normalize=True), iteration)
-            tbwriter.add_image(
-                prefix + ' real images', vutils.make_grid(real_images, normalize=True), iteration)
-        except Exception as e:
-            print(
-                ("Failed to update tensorboard images for model:  {0}").format(e))
+        tbwriter.add_image(
+            tag=prefix + ' orig images', img_tensor=vutils.make_grid(orig_images, normalize=True), global_step=iteration)
+        tbwriter.add_image(
+            tag=prefix + ' gen images', img_tensor=vutils.make_grid(gen_images, normalize=True), global_step=iteration)
+        tbwriter.add_image(
+            tag=prefix + ' real images', img_tensor=vutils.make_grid(real_images, normalize=True), global_step=iteration)
 
 
 #--------Below are what you actually want to use, in practice----------------#
 
 class LearnerTensorboardWriter(LearnerCallback):
-    def __init__(self, learn: Learner, base_dir: Path, name: str, loss_iters: int = 25, weight_iters: int = 1000, stats_iters: int = 1000):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000, stats_iters:int=1000):
         super().__init__(learn=learn)
         self.base_dir = base_dir
         self.name = name
@@ -151,60 +152,58 @@ class LearnerTensorboardWriter(LearnerCallback):
         if update_batches:
             self.data = self.learn.data
             self.trn_batch = self.learn.data.one_batch(
-                DatasetType.Train, detach=True, denorm=False, cpu=False)
+                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
             self.val_batch = self.learn.data.one_batch(
-                DatasetType.Valid, detach=True, denorm=False, cpu=False)
+                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
 
-    def _write_model_stats(self, iteration):
+    def _write_model_stats(self, iteration:int):
         self.model_vis.write_tensorboard_stats(
             model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
 
-    def _write_training_loss(self, iteration, last_loss):
-        trn_loss = to_np(last_loss)
-        self.tbwriter.add_scalar(
-            self.metrics_root + 'train_loss', trn_loss, iteration)
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
+        scalar_value = to_np(last_loss)
+        tag = self.metrics_root + 'train_loss'
+        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
 
-    def _write_weight_histograms(self, iteration):
+    def _write_weight_histograms(self, iteration:int):
         self.weight_vis.write_tensorboard_histograms(
             model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
 
-    def _write_metrics(self, iteration, last_metrics, start_idx: int = 2):
+    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
+    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
         recorder = self.learn.recorder
 
         for i, name in enumerate(recorder.names[start_idx:]):
-            if len(last_metrics) < i+1:
-                return
-            value = last_metrics[i]
-            self.tbwriter.add_scalar(
-                self.metrics_root + name, value, iteration)
-
-    def on_batch_end(self, last_loss, metrics, iteration, **kwargs):
-        if iteration == 0:
-            return
+            if len(last_metrics) < i+1: return
+            scalar_value = last_metrics[i]
+            tag = self.metrics_root + name
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
+
+    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
+        if iteration == 0: return
         self._update_batches_if_needed()
 
         if iteration % self.loss_iters == 0:
-            self._write_training_loss(iteration, last_loss)
+            self._write_training_loss(iteration=iteration, last_loss=last_loss)
 
         if iteration % self.weight_iters == 0:
-            self._write_weight_histograms(iteration)
+            self._write_weight_histograms(iteration=iteration)
 
     # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
-    def on_backward_end(self, iteration, **kwargs):
-        if iteration == 0:
-            return
+    def on_backward_end(self, iteration:int, **kwargs):
+        if iteration == 0: return
         self._update_batches_if_needed()
 
         if iteration % self.stats_iters == 0:
-            self._write_model_stats(iteration)
-
-    def on_epoch_end(self, metrics, last_metrics, iteration, **kwargs):
-        self._write_metrics(iteration, last_metrics)
+            self._write_model_stats(iteration=iteration)
 
+    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
+        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
 
+# TODO:  We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a")
 class GANTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn: Learner, base_dir: Path, name: str, loss_iters: int = 25, weight_iters: int = 1000,
-                 stats_iters: int = 1000, visual_iters: int = 100):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000,
+                 stats_iters:int=1000, visual_iters:int=100):
         super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
                          weight_iters=weight_iters, stats_iters=stats_iters)
         self.visual_iters = visual_iters
@@ -213,7 +212,7 @@ class GANTensorboardWriter(LearnerTensorboardWriter):
         self.crit_stats_updated = True
 
     # override
-    def _write_weight_histograms(self, iteration):
+    def _write_weight_histograms(self, iteration:int):
         trainer = self.learn.gan_trainer
         generator = trainer.generator
         critic = trainer.critic
@@ -223,81 +222,85 @@ class GANTensorboardWriter(LearnerTensorboardWriter):
             model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
 
     # override
-    def _write_model_stats(self, iteration):
+    def _write_model_stats(self, iteration:int):
         trainer = self.learn.gan_trainer
         generator = trainer.generator
         critic = trainer.critic
 
-        # Don't want to write stats when model has zeroed out gradients
+        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
         gen_mode = trainer.gen_mode
 
-        if gen_mode:
+        if gen_mode and not self.gen_stats_updated:
             self.model_vis.write_tensorboard_stats(
                 model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
             self.gen_stats_updated = True
-        else:
+
+        if not gen_mode and not self.crit_stats_updated:
             self.model_vis.write_tensorboard_stats(
                 model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
             self.crit_stats_updated = True
 
     # override
-    def _write_training_loss(self, iteration, last_loss):
+    def _write_training_loss(self, iteration:int, last_loss:Tensor):
         trainer = self.learn.gan_trainer
         recorder = trainer.recorder
 
         if len(recorder.losses) > 0:
-            trn_loss = to_np((recorder.losses[-1:])[0])
-            self.tbwriter.add_scalar(
-                self.metrics_root + 'train_loss', trn_loss, iteration)
+            scalar_value = to_np((recorder.losses[-1:])[0])
+            tag = self.metrics_root + 'train_loss'
+            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
 
-    def _write_images(self, iteration):
+    def _write_images(self, iteration:int):
         trainer = self.learn.gan_trainer
+        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  Is there a better way?
         gen_mode = trainer.gen_mode
-        trainer.switch(gen_mode=True)
-        self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                                  iteration=iteration, tbwriter=self.tbwriter)
-        trainer.switch(gen_mode=gen_mode)
+
+        try:
+            trainer.switch(gen_mode=True)
+            self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
+                                                    iteration=iteration, tbwriter=self.tbwriter)
+        finally:                                      
+            trainer.switch(gen_mode=gen_mode)
 
     # override
-    def on_batch_end(self, metrics, iteration, **kwargs):
-        super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
-        if iteration == 0:
-            return
+    def on_batch_end(self, iteration:int, **kwargs):
+        super().on_batch_end(iteration=iteration, **kwargs)
+        if iteration == 0: return
         if iteration % self.visual_iters == 0:
-            self._write_images(iteration)
+            self._write_images(iteration=iteration)
 
     # override
-    def on_backward_end(self, iteration, **kwargs):
-        if iteration == 0:
-            return
+    def on_backward_end(self, iteration:int, **kwargs):
+        if iteration == 0: return
         self._update_batches_if_needed()
 
+        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill. 
+        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
         if iteration % self.stats_iters == 0:
             self.gen_stats_updated = False
             self.crit_stats_updated = False
 
         if not (self.gen_stats_updated and self.crit_stats_updated):
-            self._write_model_stats(iteration)
+            self._write_model_stats(iteration=iteration)
 
 
 class ImageGenTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn: Learner, base_dir: Path, name: str, loss_iters: int = 25, weight_iters: int = 1000,
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000,
                  stats_iters: int = 1000, visual_iters: int = 100):
         super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, weight_iters=weight_iters,
                          stats_iters=stats_iters)
         self.visual_iters = visual_iters
         self.img_gen_vis = ImageGenVisualizer()
 
-    def _write_images(self, iteration):
+    def _write_images(self, iteration:int):
         self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                                   iteration=iteration, tbwriter=self.tbwriter)
 
     # override
-    def on_batch_end(self, metrics, iteration, **kwargs):
-        super().on_batch_end(metrics=metrics, iteration=iteration, **kwargs)
+    def on_batch_end(self, iteration:int, **kwargs):
+        super().on_batch_end(iteration=iteration, **kwargs)
 
-        if iteration == 0:
-            return
+        if iteration == 0: return
 
         if iteration % self.visual_iters == 0:
-            self._write_images(iteration)
+            self._write_images(iteration=iteration)