|
@@ -47,10 +47,17 @@ class BaseFilter(IFilter):
|
|
|
x = x.to(self.device)
|
|
|
x.div_(255)
|
|
|
x, y = self.norm((x, x), do_x=True)
|
|
|
-
|
|
|
- result = self.learn.pred_batch(
|
|
|
- ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
|
|
|
- )
|
|
|
+
|
|
|
+ try:
|
|
|
+ result = self.learn.pred_batch(
|
|
|
+ ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
|
|
|
+ )
|
|
|
+ except RuntimeError as rerr:
|
|
|
+ if 'memory' not in str(rerr):
|
|
|
+ raise rerr
|
|
|
+ print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
|
|
|
+ return model_image
|
|
|
+
|
|
|
out = result[0]
|
|
|
out = self.denorm(out.px, do_x=False)
|
|
|
out = image2np(out * 255).astype(np.uint8)
|