|
@@ -24,6 +24,7 @@ class BaseFilter(IFilter):
|
|
|
def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
|
|
|
super().__init__()
|
|
|
self.learn = learn
|
|
|
+ self.device = next(self.learn.model.parameters()).device
|
|
|
self.norm, self.denorm = normalize_funcs(*stats)
|
|
|
|
|
|
def _transform(self, image: PilImage) -> PilImage:
|
|
@@ -43,10 +44,12 @@ class BaseFilter(IFilter):
|
|
|
def _model_process(self, orig: PilImage, sz: int) -> PilImage:
|
|
|
model_image = self._get_model_ready_image(orig, sz)
|
|
|
x = pil2tensor(model_image, np.float32)
|
|
|
+ x = x.to(self.device)
|
|
|
x.div_(255)
|
|
|
x, y = self.norm((x, x), do_x=True)
|
|
|
+
|
|
|
result = self.learn.pred_batch(
|
|
|
- ds_type=DatasetType.Valid, batch=(x[None].cuda(), y[None]), reconstruct=True
|
|
|
+ ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
|
|
|
)
|
|
|
out = result[0]
|
|
|
out = self.denorm(out.px, do_x=False)
|