Browse files

Fixing visualization renders; Adding Wasserstein-based feature loss; Various fixes to TensorBoard code

Jason Antic 6 years ago
parent
commit
ffa39c0886
3 changed files with 219 additions and 102 deletions
  1. fasterai/filters.py  + 5 - 5
  2. fasterai/loss.py  + 69 - 0
  3. fasterai/tensorboard.py  + 145 - 97

+ 5 - 5
fasterai/filters.py

@@ -43,11 +43,11 @@ class BaseFilter(IFilter):
         x.div_(255)
         x,y = self.norm((x,x), do_x=True)
         result = self.learn.pred_batch(ds_type=DatasetType.Valid, 
-            batch=(x[None].cuda(),y[None]), reconstruct=False)
-        result = result[0]
-        result = self.denorm(result, do_x=True)
-        result = image2np(result*255).astype(np.uint8)
-        return PilImage.fromarray(result)
+            batch=(x[None].cuda(),y[None]), reconstruct=True)
+        out = result[0]
+        out = self.denorm(out.px, do_x=False)
+        out = image2np(out*255).astype(np.uint8)
+        return PilImage.fromarray(out)
 
     def _unsquare(self, image:PilImage, orig:PilImage)->PilImage:
         targ_sz = orig.size
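
Note on the hunk above: it relies on fastai v1's Learner.pred_batch, where reconstruct=True returns reconstructed Image objects (pixel data on .px) instead of raw prediction tensors, which is why the result is now denormalized on the target side (do_x=False). A minimal sketch of the resulting call pattern, restating the change for clarity rather than adding commit code; learn, x, y, and denorm are as in BaseFilter:

    result = learn.pred_batch(ds_type=DatasetType.Valid,
                              batch=(x[None].cuda(), y[None]), reconstruct=True)
    img = result[0]                           # a fastai Image, not a raw tensor
    px = denorm(img.px, do_x=False)           # denormalize as the target side
    arr = image2np(px * 255).astype(np.uint8)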

+ 69 - 0
fasterai/loss.py

@@ -88,3 +88,72 @@ class FeatureLoss2(nn.Module):
     
     def __del__(self): 
         self.hooks.remove()
+
+
+#Includes Wasserstein loss
+class FeatureLoss3(nn.Module):
+    def __init__(self, layer_wgts=[5,15,2], wass_wgts=[3.0,0.7,0.01]):
+        super().__init__()
+
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.wass_wgts = wass_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
+              ] + [f'wass_{i}' for i in range(len(layer_ids))]
+        self.base_loss = F.l1_loss
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def _calc_2_moments(self, tensor):
+        chans = tensor.shape[1]
+        tensor = tensor.view(1, chans, -1)
+        n = tensor.shape[2] 
+        mu = tensor.mean(2)
+        tensor = (tensor - mu[:,:,None]).squeeze(0)
+        cov = torch.mm(tensor, tensor.t()) / float(n)   
+        return mu, cov
+
+    def _get_style_vals(self, tensor):
+        mean, cov = self._calc_2_moments(tensor) 
+        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
+        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))     
+        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())  
+        tr_cov = eigvals.clamp(min=0).sum() 
+        return mean, tr_cov, root_cov
+
+    def _calc_l2wass_dist(self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth):
+        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
+        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
+        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
+        var_overlap = torch.sqrt(torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0)+1e-8).sum()
+        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
+        return dist
+
+    def _single_wass_loss(self, pred, targ):
+        mean_test, tr_cov_test, root_cov_test = targ
+        mean_synth, cov_synth = self._calc_2_moments(pred)
+        loss = self._calc_l2wass_dist(mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth)
+        return loss
+    
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        styles = [self._get_style_vals(i) for i in out_feat]
+        self.feat_losses += [self._single_wass_loss(f_pred, f_targ)*w
+                            for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+    
+    def __del__(self): self.hooks.remove()
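
For reference, _calc_l2wass_dist above implements the closed-form squared 2-Wasserstein (Fréchet) distance between the Gaussians fitted to the target and predicted feature maps; mean_diff_squared, tr_cov_stl, tr_cov_synth, and var_overlap correspond to the four terms of the standard identity (not spelled out in the commit itself):

    W_2^2\bigl(\mathcal{N}(\mu_1,\Sigma_1),\,\mathcal{N}(\mu_2,\Sigma_2)\bigr)
        = \lVert \mu_1 - \mu_2 \rVert_2^2
        + \operatorname{tr}(\Sigma_1) + \operatorname{tr}(\Sigma_2)
        - 2\,\operatorname{tr}\!\Bigl(\bigl(\Sigma_1^{1/2}\,\Sigma_2\,\Sigma_1^{1/2}\bigr)^{1/2}\Bigr)

var_overlap computes the last trace as the sum of square roots of the eigenvalues of root_cov_stl @ cov_synth @ root_cov_stl, clamped at zero (plus a small epsilon) for numerical safety.

A minimal usage sketch, not part of the commit; the DataBunch `data` and the resnet34 backbone are placeholders:

    # Hypothetical wiring of FeatureLoss3 into a fastai v1 learner.
    from fastai.vision import unet_learner, models

    learn = unet_learner(data, models.resnet34, loss_func=FeatureLoss3())
    # After each forward pass, learn.loss_func.metrics maps 'pixel', 'feat_i'
    # and 'wass_i' to their component values, ready for the writers below.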

+ 145 - 97
fasterai/tensorboard.py

@@ -26,16 +26,16 @@ class ModelHistogramVisualizer():
     def __init__(self):
         return 
 
-    def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
-        for name, param in model.named_parameters():
-            tbwriter.add_histogram('/weights/' + name, param, iter_count)
+    def write_tensorboard_histograms(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
+        for param_name, param in model.named_parameters():
+            tbwriter.add_histogram(name + '/weights/' + param_name, param, iter_count)
 
 
 class ModelStatsVisualizer(): 
     def __init__(self):
         return 
 
-    def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter):
+    def write_tensorboard_stats(self, model:nn.Module, iter_count:int, tbwriter:SummaryWriter, name:str='model'):
         gradients = [x.grad  for x in model.parameters() if x.grad is not None]
         gradient_nps = [to_np(x.data) for x in gradients]
  
@@ -43,32 +43,32 @@ class ModelStatsVisualizer():
             return 
 
         avg_norm = sum(x.data.norm() for x in gradients)/len(gradients)
-        tbwriter.add_scalar('/gradients/avg_norm', avg_norm, iter_count)
+        tbwriter.add_scalar(name + '/gradients/avg_norm', avg_norm, iter_count)
 
         median_norm = statistics.median(x.data.norm() for x in gradients)
-        tbwriter.add_scalar('/gradients/median_norm', median_norm, iter_count)
+        tbwriter.add_scalar(name + '/gradients/median_norm', median_norm, iter_count)
 
         max_norm = max(x.data.norm() for x in gradients)
-        tbwriter.add_scalar('/gradients/max_norm', max_norm, iter_count)
+        tbwriter.add_scalar(name + '/gradients/max_norm', max_norm, iter_count)
 
         min_norm = min(x.data.norm() for x in gradients)
-        tbwriter.add_scalar('/gradients/min_norm', min_norm, iter_count)
+        tbwriter.add_scalar(name + '/gradients/min_norm', min_norm, iter_count)
 
         num_zeros = sum((np.asarray(x)==0.0).sum() for x in  gradient_nps)
-        tbwriter.add_scalar('/gradients/num_zeros', num_zeros, iter_count)
+        tbwriter.add_scalar(name + '/gradients/num_zeros', num_zeros, iter_count)
 
 
         avg_gradient= sum(x.data.mean() for x in gradients)/len(gradients)
-        tbwriter.add_scalar('/gradients/avg_gradient', avg_gradient, iter_count)
+        tbwriter.add_scalar(name + '/gradients/avg_gradient', avg_gradient, iter_count)
 
         median_gradient = statistics.median(x.data.median() for x in gradients)
-        tbwriter.add_scalar('/gradients/median_gradient', median_gradient, iter_count)
+        tbwriter.add_scalar(name + '/gradients/median_gradient', median_gradient, iter_count)
 
         max_gradient = max(x.data.max() for x in gradients) 
-        tbwriter.add_scalar('/gradients/max_gradient', max_gradient, iter_count)
+        tbwriter.add_scalar(name + '/gradients/max_gradient', max_gradient, iter_count)
 
         min_gradient = min(x.data.min() for x in gradients) 
-        tbwriter.add_scalar('/gradients/min_gradient', min_gradient, iter_count)
+        tbwriter.add_scalar(name + '/gradients/min_gradient', min_gradient, iter_count)
 
 class ImageGenVisualizer():
     def output_image_gen_visuals(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iter_count:int, tbwriter:SummaryWriter):
@@ -98,51 +98,27 @@ class ImageGenVisualizer():
 
 #--------Below are what you actually want to use, in practice----------------#
 
-class ModelTensorboardStatsWriter():
-    def __init__(self, base_dir: Path, module: nn.Module, name: str, stats_iters: int=10):
-        self.base_dir = base_dir
-        self.name = name
-        log_dir = base_dir/name
-        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.hook = module.register_forward_hook(self.forward_hook)
-        self.stats_iters = stats_iters
-        self.iter_count = 0
-        self.model_vis = ModelStatsVisualizer() 
-
-    def forward_hook(self, module:nn.Module, input, output): 
-        self.iter_count += 1
-        if self.iter_count % self.stats_iters == 0:
-            self.model_vis.write_tensorboard_stats(module, iter_count=self.iter_count, tbwriter=self.tbwriter)  
-
 
-    def close(self):
-        self.tbwriter.close()
-        self.hook.remove()
 
-class GANTensorboardWriter(LearnerCallback):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, stats_iters:int=10, 
-            visual_iters:int=200, weight_iters:int=1000):
+class LearnerTensorboardWriter(LearnerCallback):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000, stats_iters:int=1000):
         super().__init__(learn=learn)
         self.base_dir = base_dir
         self.name = name
         log_dir = base_dir/name
         self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.stats_iters = stats_iters
-        self.visual_iters = visual_iters
+        self.loss_iters = loss_iters
         self.weight_iters = weight_iters
-        self.img_gen_vis = ImageGenVisualizer()
-        self.graph_vis = ModelGraphVisualizer()
+        self.stats_iters = stats_iters
+        self.iter_count = 0
         self.weight_vis = ModelHistogramVisualizer()
+        self.model_vis = ModelStatsVisualizer() 
         self.data = None
+        #Keeping track of iterations in the callback, because the callback can be used for multiple epochs and multiple fit calls.
+        #This ensures that graphs show continuous iterations rather than resetting to 0 (which makes them much harder to read!)
+        self.iteration = -1
 
-    def on_batch_end(self, iteration, metrics, **kwargs):
-        if iteration==0:
-            return
-
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-        recorder = trainer.recorder
+    def _update_batches_if_needed(self):
         #one_batch is extremely slow. This is an optimization.
         update_batches = self.data is not self.learn.data
 
@@ -151,71 +127,143 @@ class GANTensorboardWriter(LearnerCallback):
             self.trn_batch = self.learn.data.one_batch(DatasetType.Train, detach=False, denorm=False)
             self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=False, denorm=False)
 
-        if iteration % self.stats_iters == 0:  
-            if len(recorder.losses) > 0:      
-                trn_loss = to_np((recorder.losses[-1:])[0])
-                self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
+    def _write_model_stats(self, iteration):
+        self.model_vis.write_tensorboard_stats(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter) 
 
-            if len(recorder.val_losses) > 0:
-                val_loss = (recorder.val_losses[-1:])[0]
-                self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration) 
+    def _write_training_loss(self, iteration, last_loss):
+        trn_loss = to_np(last_loss)
+        self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
 
-            #TODO:  Figure out how to do metrics here and gan vs critic loss
-            #values = [met[-1:] for met in recorder.metrics]
+    def _write_weight_histograms(self, iteration):
+        self.weight_vis.write_tensorboard_histograms(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter)
 
-        if iteration % self.visual_iters == 0:
-            gen_mode = trainer.gen_mode
-            trainer.switch(gen_mode=True)
-            self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
-                                                    iter_count=iteration, tbwriter=self.tbwriter)
-            trainer.switch(gen_mode=gen_mode)
+    def _write_val_loss(self, iteration, last_metrics):
+        #TODO: Not a fan of this indexing but...what to do?
+        val_loss = last_metrics[0]
+        self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration)  
+    
+    def _write_metrics(self, iteration):
+        rec = self.learn.recorder
+
+        for i, name in enumerate(rec.names[3:]):
+            if len(rec.metrics) == 0: continue
+            if len(rec.metrics[-1:]) == 0: continue
+            if len(rec.metrics[-1:][0]) == 0: continue
+            value = rec.metrics[-1:][0][i]
+            if value is None: continue
+            self.tbwriter.add_scalar('/metrics/' + name, to_np(value), iteration) 
+
+    def on_batch_end(self, last_loss, metrics, **kwargs):
+        self.iteration += 1
+        iteration = self.iteration
+
+        if iteration==0:
+            return
+
+        self._update_batches_if_needed()
+
+        if iteration % self.loss_iters == 0: 
+            self._write_training_loss(iteration, last_loss)
 
         if iteration % self.weight_iters == 0:
-            self.weight_vis.write_tensorboard_histograms(model=generator, iter_count=iteration, tbwriter=self.tbwriter)
-            self.weight_vis.write_tensorboard_histograms(model=critic, iter_count=iteration, tbwriter=self.tbwriter)
-              
+            self._write_weight_histograms(iteration)
 
+        if iteration % self.stats_iters == 0:
+            self._write_model_stats(iteration)
 
-class ImageGenTensorboardWriter(LearnerCallback):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, stats_iters:int=25, 
-            visual_iters:int=200, weight_iters:int=25):
-        super().__init__(learn=learn)
-        self.base_dir = base_dir
-        self.name = name
-        log_dir = base_dir/name
-        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.stats_iters = stats_iters
+    def on_epoch_end(self, metrics, last_metrics, **kwargs):
+        iteration = self.iteration  
+        self._write_val_loss(iteration, last_metrics)
+        self._write_metrics(iteration)
+
+
+class GANTensorboardWriter(LearnerTensorboardWriter):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000, 
+                stats_iters:int=1000, visual_iters:int=100):
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, 
+                        weight_iters=weight_iters, stats_iters=stats_iters)
         self.visual_iters = visual_iters
-        self.weight_iters = weight_iters
-        self.iter_count = 0
-        self.weight_vis = ModelHistogramVisualizer()
         self.img_gen_vis = ImageGenVisualizer()
-        self.data = None
 
-    def on_batch_end(self, iteration, last_loss, metrics, **kwargs):
+    #override
+    def _write_training_loss(self, iteration, last_loss):
+        trainer = self.learn.gan_trainer
+        recorder = trainer.recorder
+
+        if len(recorder.losses) > 0:      
+            trn_loss = to_np((recorder.losses[-1:])[0])
+            self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
+
+    #override
+    def _write_weight_histograms(self, iteration):
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+
+        self.weight_vis.write_tensorboard_histograms(model=generator, iter_count=iteration, tbwriter=self.tbwriter, name='generator')
+        self.weight_vis.write_tensorboard_histograms(model=critic, iter_count=iteration, tbwriter=self.tbwriter, name='critic')
+
+    #override
+    def _write_model_stats(self, iteration):
+        trainer = self.learn.gan_trainer
+        generator = trainer.generator
+        critic = trainer.critic
+
+        self.model_vis.write_tensorboard_stats(model=generator, iter_count=iteration, tbwriter=self.tbwriter, name='generator')
+        self.model_vis.write_tensorboard_stats(model=critic, iter_count=iteration, tbwriter=self.tbwriter, name='critic')
+
+    #override
+    def _write_val_loss(self, iteration, last_metrics):
+        trainer = self.learn.gan_trainer
+        recorder = trainer.recorder 
+
+        if len(recorder.val_losses) > 0:
+            val_loss = (recorder.val_losses[-1:])[0]
+            self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration) 
+
+
+    def _write_images(self, iteration):
+        trainer = self.learn.gan_trainer
+        recorder = trainer.recorder
+
+        gen_mode = trainer.gen_mode
+        trainer.switch(gen_mode=True)
+        self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
+                                                iter_count=iteration, tbwriter=self.tbwriter)
+        trainer.switch(gen_mode=gen_mode)
+
+    def on_batch_end(self, metrics, **kwargs):
+        super().on_batch_end(metrics=metrics, **kwargs)
+
+        iteration = self.iteration
+
         if iteration==0:
             return
 
-        #one_batch is extremely slow.  this is an optimization
-        update_batches = self.data is not self.learn.data
+        if iteration % self.visual_iters == 0:
+            self._write_images(iteration)
 
-        if update_batches:
-            self.data = self.learn.data
-            self.trn_batch = self.learn.data.one_batch(DatasetType.Train, detach=False, denorm=False)
-            self.val_batch = self.learn.data.one_batch(DatasetType.Valid, detach=False, denorm=False)
+              
 
-        if iteration % self.stats_iters == 0: 
-            trn_loss = to_np(last_loss)
-            self.tbwriter.add_scalar('/loss/trn_loss', trn_loss, iteration)
+class ImageGenTensorboardWriter(LearnerTensorboardWriter):
+    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, weight_iters:int=1000, 
+                stats_iters:int=1000, visual_iters:int=100):
+        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, weight_iters=weight_iters, 
+                        stats_iters=stats_iters)
+        self.visual_iters = visual_iters
+        self.img_gen_vis = ImageGenVisualizer()
 
-        if iteration % self.visual_iters == 0:
-            self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
-                iter_count=iteration, tbwriter=self.tbwriter)
+    def _write_images(self, iteration):
+        self.img_gen_vis.output_image_gen_visuals(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, 
+            iter_count=iteration, tbwriter=self.tbwriter)
 
-        if iteration % self.weight_iters == 0:
-            self.weight_vis.write_tensorboard_histograms(model=self.learn.model, iter_count=iteration, tbwriter=self.tbwriter)
+    def on_batch_end(self, metrics, **kwargs):
+        super().on_batch_end(metrics=metrics, **kwargs)
 
-    def on_epoch_end(self, iteration, metrics, last_metrics, **kwargs):  
-        #TODO: Not a fan of this indexing but...what to do?
-        val_loss = last_metrics[0]
-        self.tbwriter.add_scalar('/loss/val_loss', val_loss, iteration)   
+        iteration = self.iteration
+
+        if iteration==0:
+            return
+
+        if iteration % self.visual_iters == 0:
+            self._write_images(iteration)
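
A minimal usage sketch for the refactored writers, not part of the commit; the Learner `learn`, the log path, and the run name are placeholders:

    # Hypothetical usage, assuming fastai v1 and an existing Learner `learn`.
    from pathlib import Path

    learn.callbacks.append(LearnerTensorboardWriter(
        learn=learn, base_dir=Path('data/tensorboard'), name='run1'))
    learn.fit_one_cycle(1)

GAN training would append GANTensorboardWriter instead, which overrides the loss, histogram, and stats writers to route through learn.gan_trainer's generator and critic, and additionally writes image grids every visual_iters batches.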