|
@@ -37,11 +37,15 @@ class ModelImageVisualizer():
|
|
|
def get_transformed_image_as_pil(self, path:str, model:nn.Module, sz:int=None, tfms:[Transform]=[])->Image:
|
|
|
path = Path(path)
|
|
|
array = self.get_transformed_image_ndarray(path, model, sz, tfms=tfms)
|
|
|
- return misc.toimage(array)
|
|
|
+ return self._convert_array_to_pil_image(array)
|
|
|
+
|
|
|
+ def _convert_array_to_pil_image(self, array:ndarray):
|
|
|
+ return Image.fromarray((array*255).astype('uint8'))
|
|
|
|
|
|
def _save_result_image(self, source_path:Path, result:ndarray):
|
|
|
result_path = self.results_dir/source_path.name
|
|
|
- misc.imsave(result_path, result)
|
|
|
+ im = self._convert_array_to_pil_image(result)
|
|
|
+ im.save(result_path)
|
|
|
|
|
|
def plot_images_from_image_sets(self, image_sets:[ModelImageSet], validation:bool, figsize:(int,int)=(20,20),
|
|
|
max_columns:int=6, immediate_display:bool=True):
|