transforms_pil.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import torch
  2. import numpy as np
  3. class Cutout(object):
  4. """Randomly mask out one or more patches from an image.
  5. Args:
  6. n_holes (int): Number of patches to cut out of each image.
  7. length (int): The length (in pixels) of each square patch.
  8. """
  9. def __init__(self, n_holes, length):
  10. self.n_holes = n_holes
  11. self.length = length
  12. def __call__(self, img):
  13. """
  14. Args:
  15. img (Tensor): Tensor image of size (C, H, W).
  16. Returns:
  17. Tensor: Image with n_holes of dimension length x length cut out of it.
  18. """
  19. h = img.size(1)
  20. w = img.size(2)
  21. mask = np.ones((h, w), np.float32)
  22. for n in range(self.n_holes):
  23. y = np.random.randint(h)
  24. x = np.random.randint(w)
  25. y1 = np.clip(y - self.length / 2, 0, h)
  26. y2 = np.clip(y + self.length / 2, 0, h)
  27. x1 = np.clip(x - self.length / 2, 0, w)
  28. x2 = np.clip(x + self.length / 2, 0, w)
  29. mask[y1: y2, x1: x2] = 0.
  30. mask = torch.from_numpy(mask)
  31. mask = mask.expand_as(img)
  32. img = img * mask
  33. return img