dataloader.py

import torch, queue
from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
from concurrent.futures import ThreadPoolExecutor  # used by __iter__; imported explicitly in case .imports does not re-export it
from .imports import *
from .core import *
import collections, sys, traceback, threading

# Local stand-in for torch's internal string_classes, which newer PyTorch
# versions no longer expose.
string_classes = (str, bytes)
# Recursively convert a collated numpy batch into GPU tensors, optionally
# pinning host memory and/or casting to half precision first.
def get_tensor(batch, pin, half=False):
    if isinstance(batch, (np.ndarray, np.generic)):
        batch = T(batch, half=half, cuda=False).contiguous()
        if pin: batch = batch.pin_memory()
        return to_gpu(batch)
    elif isinstance(batch, string_classes): return batch
    elif isinstance(batch, collections.abc.Mapping):
        return {k: get_tensor(sample, pin, half) for k, sample in batch.items()}
    elif isinstance(batch, collections.abc.Sequence):
        return [get_tensor(sample, pin, half) for sample in batch]
    raise TypeError(f"batch must contain numbers, dicts or lists; found {type(batch)}")
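
# Illustrative sketch (added comment, not original code): get_tensor recurses
# through nested containers, converting each ndarray leaf, assuming T() and
# to_gpu() from fastai's core move data to the GPU when one is available:
#   get_tensor({'x': np.zeros((2, 3)), 'y': [np.ones(2)]}, pin=False)
# would return {'x': <2x3 tensor>, 'y': [<length-2 tensor>]}.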
class DataLoader(object):
    """Numpy-collating data loader that pads jagged sequences and moves
    batches to the GPU; a lightweight alternative to torch's DataLoader."""
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, pad_idx=0,
                 num_workers=None, pin_memory=False, drop_last=False, pre_pad=True, half=False,
                 transpose=False, transpose_y=False):
        self.dataset,self.batch_size,self.num_workers = dataset,batch_size,num_workers
        self.pin_memory,self.drop_last,self.pre_pad = pin_memory,drop_last,pre_pad
        self.transpose,self.transpose_y,self.pad_idx,self.half = transpose,transpose_y,pad_idx,half
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')
        if batch_sampler is None:
            if sampler is None:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        if num_workers is None:
            self.num_workers = num_cpus()
        self.sampler = sampler
        self.batch_sampler = batch_sampler
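
    # Wiring sketch (added comment): with shuffle=True and no explicit sampler,
    # __init__ builds BatchSampler(RandomSampler(dataset), batch_size, drop_last),
    # which yields one list of shuffled indices per batch.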
    def __len__(self): return len(self.batch_sampler)
    # Stack equal-length arrays directly; otherwise pad jagged 1D/2D arrays
    # to the longest length with pad_idx (left-padded if pre_pad).
    def jag_stack(self, b):
        if len(b[0].shape) not in (1,2): return np.stack(b)
        ml = max(len(o) for o in b)
        if min(len(o) for o in b)==ml: return np.stack(b)
        res = np.zeros((len(b), ml), dtype=b[0].dtype) + self.pad_idx
        for i,o in enumerate(b):
            if self.pre_pad: res[i, -len(o):] = o
            else:            res[i, :len(o)] = o
        return res
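
    # Padding sketch (added comment): with the defaults pad_idx=0 and pre_pad=True,
    #   jag_stack([np.array([1, 2]), np.array([3, 4, 5])])
    # returns array([[0, 1, 2], [3, 4, 5]]); with pre_pad=False it returns
    # array([[1, 2, 0], [3, 4, 5]]).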
    def np_collate(self, batch):
        b = batch[0]
        if isinstance(b, (np.ndarray, np.generic)): return self.jag_stack(batch)
        elif isinstance(b, (int, float)): return np.array(batch)
        elif isinstance(b, string_classes): return batch
        elif isinstance(b, collections.abc.Mapping):
            return {key: self.np_collate([d[key] for d in batch]) for key in b}
        elif isinstance(b, collections.abc.Sequence):
            return [self.np_collate(samples) for samples in zip(*batch)]
        raise TypeError(f"batch must contain numbers, dicts or lists; found {type(b)}")
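
    # Collation sketch (added comment): a batch of (x, y) pairs is zipped and
    # collated per field, so [(x0, y0), (x1, y1)] becomes
    # [jag_stack([x0, x1]), np.array([y0, y1])] for array x's and scalar y's.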
    def get_batch(self, indices):
        res = self.np_collate([self.dataset[i] for i in indices])
        # Optionally transpose inputs and/or targets (e.g. for seq-first RNNs).
        if self.transpose:   res[0] = res[0].T
        if self.transpose_y: res[1] = res[1].T
        return res
    def __iter__(self):
        if self.num_workers==0:
            # Synchronous path: collate and transfer batches on the calling thread.
            for batch in map(self.get_batch, iter(self.batch_sampler)):
                yield get_tensor(batch, self.pin_memory, self.half)
        else:
            with ThreadPoolExecutor(max_workers=self.num_workers) as e:
                # avoid py3.6 issue where queue is infinite and can result in memory exhaustion
                for c in chunk_iter(iter(self.batch_sampler), self.num_workers*10):
                    for batch in e.map(self.get_batch, c):
                        yield get_tensor(batch, self.pin_memory, self.half)
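
# Usage sketch (added comment; train_ds and model are illustrative names, not
# part of this module). Assumes a dataset whose __getitem__ returns (x, y)
# numpy pairs:
#   dl = DataLoader(train_ds, batch_size=64, shuffle=True, pin_memory=True)
#   for x, y in dl:       # each batch is collated by np_collate, then moved
#       preds = model(x)  # to the GPU by get_tensor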