1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import numpy as np
- from fastai.torch_imports import *
- from fastai.core import *
- from fastai.dataset import FilesDataset
- from pathlib import Path
- from itertools import repeat
- from PIL import Image
- from numpy import ndarray
- from datetime import datetime
class EasyTensorImage():
    """Hold a denormalized ndarray built from a tensor, plus a tensor built from that array.

    Both attributes are populated by the constructor:
      array  -- output of ``_convert_to_denormed_ndarray`` for the source tensor
      tensor -- output of ``_convert_to_denormed_tensor`` applied to ``array``
    """

    def __init__(self, source_tensor:torch.Tensor, ds:FilesDataset):
        """Convert ``source_tensor`` using ``ds`` and store both representations."""
        denormed = self._convert_to_denormed_ndarray(source_tensor, ds=ds)
        self.array = denormed
        self.tensor = self._convert_to_denormed_tensor(denormed)

    def _convert_to_denormed_ndarray(self, raw_tensor:torch.Tensor, ds:FilesDataset):
        """Return the first item of ``ds.denorm`` applied to a numpy copy of ``raw_tensor``.

        When the size of axis 1 is not 3, a zero-filled array of shape
        ``(3, 1, 1)`` is returned instead and ``ds.denorm`` is not called.
        """
        # Work on a CPU-side numpy copy so the source tensor is left untouched.
        as_numpy = raw_tensor.clone().data.cpu().numpy()
        # NOTE(review): axis 1 is presumably the channel axis of a
        # (batch, channel, height, width) tensor -- confirm against callers.
        if as_numpy.shape[1] == 3:
            return ds.denorm(as_numpy)[0]
        return np.zeros((3, 1, 1))

    def _convert_to_denormed_tensor(self, denormed_array: ndarray):
        """Move axis 2 of ``denormed_array`` to the front and wrap the result with ``V``."""
        # NOTE(review): looks like a channels-last -> channels-first reorder -- confirm.
        reordered = np.moveaxis(denormed_array, 2, 0)
        return V(reordered)
class ModelImageSet():
    """Bundle the three images belonging to one dataset sample.

    Attributes:
      orig -- image built from the model input
      real -- image built from the dataset target
      gen  -- image built from the model output
    """

    @staticmethod
    def get_list_from_model(ds:FilesDataset, model:nn.Module, idxs:[int]):
        """Run ``model`` on the samples ``ds[idx]`` for each ``idx`` in ``idxs``.

        The model is put into eval mode for the duration of the call. If it
        was in training mode on entry, training mode is restored on exit,
        including when an exception is raised part-way through.

        Returns a list with one ``ModelImageSet`` per index, in the order of
        ``idxs``. An empty ``idxs`` yields an empty list.
        """
        image_sets = []
        was_training = model.training
        model.eval()

        try:
            for idx in idxs:
                x, y = ds[idx]
                # x[None] / y[None] prepend a length-1 leading axis before wrapping.
                orig_tensor = VV(x[None])
                real_tensor = V(y[None])
                gen_tensor = model(orig_tensor)
                gen_easy = EasyTensorImage(gen_tensor, ds)
                orig_easy = EasyTensorImage(orig_tensor, ds)
                real_easy = EasyTensorImage(real_tensor, ds)
                image_sets.append(ModelImageSet(orig_easy, real_easy, gen_easy))
        finally:
            # Restore the caller's mode even if the dataset lookup, forward pass
            # or conversion failed; otherwise the model would stay in eval mode.
            if was_training:
                model.train()

        return image_sets

    def __init__(self, orig:EasyTensorImage, real:EasyTensorImage, gen:EasyTensorImage):
        """Store the input, target and generated images."""
        self.orig = orig
        self.real = real
        self.gen = gen
|