|
@@ -4,6 +4,9 @@ from fastai.vision import *
|
|
|
from fastai.callbacks import *
|
|
|
from fastai.vision.gan import *
|
|
|
from fastai.core import *
|
|
|
+from threading import Thread
|
|
|
+from time import sleep
|
|
|
+from queue import Queue
|
|
|
import statistics
|
|
|
import torchvision.utils as vutils
|
|
|
from tensorboardX import SummaryWriter
|
|
@@ -37,16 +40,50 @@ class ModelGraphVisualizer():
|
|
|
x,y = md.one_batch(ds_type=DatasetType.Valid, detach=False, denorm=False)
|
|
|
tbwriter.add_graph(model=model, input_to_model=x)
|
|
|
|
|
|
+class HistogramRequest():
|
|
|
+ def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
|
|
|
+ self.params = [(name, values.clone().detach()) for (name, values) in model.named_parameters()]
|
|
|
+ self.iteration = iteration
|
|
|
+ self.tbwriter = tbwriter
|
|
|
+ self.name = name
|
|
|
|
|
|
class ModelHistogramVisualizer():
|
|
|
def __init__(self):
|
|
|
- return
|
|
|
-
|
|
|
+ self.exit = False
|
|
|
+ self.queue = Queue()
|
|
|
+ self.thread = Thread(target=self._queue_processor)
|
|
|
+ self.thread.start()
|
|
|
+
|
|
|
+ def _queue_processor(self):
|
|
|
+ while not self.exit:
|
|
|
+ while not self.queue.empty():
|
|
|
+ request = self.queue.get()
|
|
|
+ self._write_async(request)
|
|
|
+ sleep(0.1)
|
|
|
+
|
|
|
+ def _write_async(self, request:HistogramRequest):
|
|
|
+ try:
|
|
|
+ params = request.params
|
|
|
+ iteration = request.iteration
|
|
|
+ tbwriter = request.tbwriter
|
|
|
+ name = request.name
|
|
|
+
|
|
|
+ for param_name, values in params:
|
|
|
+ tag = name + '/weights/' + param_name
|
|
|
+ tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
|
|
|
+ except Exception as e:
|
|
|
+ print(("Failed to write model histograms to Tensorboard: {0}").format(e))
|
|
|
+
|
|
|
+ #If this isn't done async then this is sloooooow
|
|
|
def write_tensorboard_histograms(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
|
|
|
- for param_name, values in model.named_parameters():
|
|
|
- tag = name + '/weights/' + param_name
|
|
|
- tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
|
|
|
+ request = HistogramRequest(model, iteration, tbwriter, name)
|
|
|
+ self.queue.put(request)
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ self.exit = True
|
|
|
+ self.thread.join()
|
|
|
|
|
|
+
|
|
|
|
|
|
class ModelStatsVisualizer():
|
|
|
def __init__(self):
|
|
@@ -103,8 +140,7 @@ class ImageGenVisualizer():
|
|
|
tbwriter=tbwriter, ds_type=DatasetType.Train)
|
|
|
|
|
|
def _output_visuals(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
|
|
|
- image_sets = ModelImageSet.get_list_from_model(
|
|
|
- learn=learn, batch=batch, ds_type=ds_type)
|
|
|
+ image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
|
|
|
self._write_tensorboard_images(
|
|
|
image_sets=image_sets, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
|
|
|
|