Просмотр исходного кода

Optimizing performance of tensorboard histogram writes

Jason Antic 6 лет назад
Родитель
Сommit
20af68f395
1 измененных файлов с 43 добавлено и 7 удалено
  1. 43 7
      fasterai/tensorboard.py

+ 43 - 7
fasterai/tensorboard.py

@@ -4,6 +4,9 @@ from fastai.vision import *
 from fastai.callbacks import *
 from fastai.vision.gan import *
 from fastai.core import *
+from threading import Thread
+from time import sleep
+from queue import Queue
 import statistics
 import torchvision.utils as vutils
 from tensorboardX import SummaryWriter
@@ -37,16 +40,50 @@ class ModelGraphVisualizer():
         x,y = md.one_batch(ds_type=DatasetType.Valid, detach=False, denorm=False)
         tbwriter.add_graph(model=model, input_to_model=x)
 
+class HistogramRequest():
+    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
+        self.params = [(name, values.clone().detach()) for (name, values) in model.named_parameters()]
+        self.iteration = iteration
+        self.tbwriter = tbwriter
+        self.name = name
 
 class ModelHistogramVisualizer():
     def __init__(self):
-        return
-
+        self.exit = False
+        self.queue = Queue()
+        self.thread = Thread(target=self._queue_processor)
+        self.thread.start()
+
+    def _queue_processor(self):
+        while not self.exit:
+            while not self.queue.empty():
+                request = self.queue.get()
+                self._write_async(request)
+            sleep(0.1)
+
+    def _write_async(self, request:HistogramRequest):
+        try:
+            params = request.params
+            iteration = request.iteration
+            tbwriter = request.tbwriter
+            name = request.name
+
+            for param_name, values in params:
+                tag = name + '/weights/' + param_name
+                tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
+        except Exception as e:
+            print(("Failed to write model histograms to Tensorboard:  {0}").format(e))
+
+    #If this isn't done async then this is sloooooow
     def write_tensorboard_histograms(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
-        for param_name, values in model.named_parameters():
-            tag = name + '/weights/' + param_name
-            tbwriter.add_histogram(tag=tag, values=values, global_step=iteration)
+        request = HistogramRequest(model, iteration, tbwriter, name)
+        self.queue.put(request)
+
+    def __del__(self):
+        self.exit = True
+        self.thread.join()
 
+    
 
 class ModelStatsVisualizer():
     def __init__(self):
@@ -103,8 +140,7 @@ class ImageGenVisualizer():
                              tbwriter=tbwriter, ds_type=DatasetType.Train)
 
     def _output_visuals(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
-        image_sets = ModelImageSet.get_list_from_model(
-            learn=learn, batch=batch, ds_type=ds_type)
+        image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
         self._write_tensorboard_images(
             image_sets=image_sets, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)