|
@@ -63,10 +63,12 @@ class ModelImageVisualizer():
|
|
|
def get_transformed_image_ndarray(self, path:Path, model:nn.Module, sz:int=None, tfms:[Transform]=[]):
|
|
|
training = model.training
|
|
|
model.eval()
|
|
|
- orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
|
|
|
- orig = VV(orig[None])
|
|
|
- result = model(orig).detach().cpu().numpy()
|
|
|
- result = self._denorm(result)
|
|
|
+ with torch.no_grad():
|
|
|
+ orig = self._get_model_ready_image_ndarray(path, model, sz, tfms)
|
|
|
+ orig = VV_(orig[None])
|
|
|
+ result = model(orig).detach().cpu().numpy()
|
|
|
+ result = self._denorm(result)
|
|
|
+
|
|
|
if training:
|
|
|
model.train()
|
|
|
return result[0]
|