images.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import numpy as np
  2. from fastai.torch_imports import *
  3. from fastai.core import *
  4. from fastai.dataset import FilesDataset
  5. from pathlib import Path
  6. from itertools import repeat
  7. from PIL import Image
  8. from numpy import ndarray
  9. from datetime import datetime
  class EasyTensorImage():
      """Denormalized view of the first image in a batch tensor.

      Attributes:
          array: the denormalized image as an ndarray. It is expected to be
              channels-last, (H, W, C), because
              ``_convert_to_denormed_tensor`` moves axis 2 to the front.
          tensor: the same image wrapped with fastai's ``V`` after the
              channel axis has been moved first, (C, H, W).

      A tensor whose dimension 1 is not 3 is not passed to ``ds.denorm``.
      A 1x1 all-zero, 3-channel placeholder image is stored instead.
      """

      def __init__(self, source_tensor:torch.Tensor, ds:FilesDataset):
          self.array = self._convert_to_denormed_ndarray(source_tensor, ds=ds)
          self.tensor = self._convert_to_denormed_tensor(self.array)

      def _convert_to_denormed_ndarray(self, raw_tensor:torch.Tensor, ds:FilesDataset):
          """Return the first image of ``raw_tensor`` denormalized by ``ds``.

          Args:
              raw_tensor: batch tensor whose dimension 1 is checked as the
                  channel count (presumably (N, C, H, W); TODO confirm).
              ds: dataset whose ``denorm`` undoes the input normalization.

          Returns:
              ``ds.denorm(...)[0]`` (the first image of the batch) when the
              tensor has 3 channels. Otherwise a ``(1, 1, 3)`` zero
              placeholder in the same channels-last layout.
          """
          # clone() keeps the ndarray from sharing storage with the caller's
          # tensor (Tensor.numpy() on a CPU tensor shares memory).
          raw_array = raw_tensor.clone().data.cpu().numpy()
          if raw_array.shape[1] != 3:
              # The placeholder must be channels-last like the denorm result,
              # so that moveaxis(2, 0) in _convert_to_denormed_tensor yields a
              # (3, 1, 1) channels-first tensor rather than a (1, 3, 1) one.
              return np.zeros((1, 1, 3))
          return ds.denorm(raw_array)[0]

      def _convert_to_denormed_tensor(self, denormed_array: ndarray):
          """Wrap a channels-last image array as a channels-first variable."""
          # (H, W, C) -> (C, H, W): move the channel axis (2) to the front.
          return V(np.moveaxis(denormed_array,2,0))
  class ModelImageSet():
      """Input, target and model output images for one dataset sample."""

      @staticmethod
      def get_list_from_model(ds:FilesDataset, model:nn.Module, idxs:[int]):
          """Run ``model`` on selected samples of ``ds`` and collect the images.

          The model is put in eval mode for the duration of the call. If it
          was in training mode on entry, training mode is restored afterwards.
          This happens even when an exception is raised part-way through.

          Args:
              ds: dataset indexed as ``ds[idx] -> (x, y)``.
              model: network applied to each input ``x``.
              idxs: indices of the samples to process.

          Returns:
              A list with one ``ModelImageSet`` per index, in ``idxs`` order.
          """
          image_sets = []
          training = model.training
          model.eval()
          try:
              for idx in idxs:
                  x, y = ds[idx]
                  # x[None] / y[None] add a leading batch dimension of size 1.
                  orig_tensor = VV(x[None])
                  real_tensor = V(y[None])
                  gen_tensor = model(orig_tensor)
                  gen_easy = EasyTensorImage(gen_tensor, ds)
                  orig_easy = EasyTensorImage(orig_tensor, ds)
                  real_easy = EasyTensorImage(real_tensor, ds)
                  image_sets.append(ModelImageSet(orig_easy, real_easy, gen_easy))
          finally:
              # Restore training mode only if the model was training on entry.
              # A model that was already in eval mode is left untouched.
              if training:
                  model.train()
          return image_sets

      def __init__(self, orig:EasyTensorImage, real:EasyTensorImage, gen:EasyTensorImage):
          self.orig = orig  # image built from the model input x
          self.real = real  # image built from the target y
          self.gen = gen    # image built from the model output model(x)