"A thread-based DataLoader: collates numpy samples (padding jagged sequences) and moves batches to the GPU."
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
from .imports import *
from .core import *
import collections.abc

string_classes = (str, bytes)

def get_tensor(batch, pin, half=False):
    "Recursively convert a collated batch of numpy arrays to (optionally pinned, half-precision) GPU tensors."
    if isinstance(batch, (np.ndarray, np.generic)):
        batch = T(batch, half=half, cuda=False).contiguous()
        if pin: batch = batch.pin_memory()
        return to_gpu(batch)
    elif isinstance(batch, string_classes): return batch
    elif isinstance(batch, collections.abc.Mapping):
        return {k: get_tensor(sample, pin, half) for k, sample in batch.items()}
    elif isinstance(batch, collections.abc.Sequence):
        return [get_tensor(sample, pin, half) for sample in batch]
    raise TypeError(f"batch must contain numbers, dicts or lists; found {type(batch)}")

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 pad_idx=0, num_workers=None, pin_memory=False, drop_last=False, pre_pad=True,
                 half=False, transpose=False, transpose_y=False):
        self.dataset,self.batch_size,self.num_workers = dataset,batch_size,num_workers
        self.pin_memory,self.drop_last,self.pre_pad = pin_memory,drop_last,pre_pad
        self.transpose,self.transpose_y,self.pad_idx,self.half = transpose,transpose_y,pad_idx,half

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        if num_workers is None: self.num_workers = num_cpus()

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __len__(self): return len(self.batch_sampler)

    def jag_stack(self, b):
        # Stack directly if all items share a length; otherwise pad jagged 1d/2d
        # items to the longest with pad_idx, at the front (pre_pad) or the back.
        if len(b[0].shape) not in (1,2): return np.stack(b)
        ml = max(len(o) for o in b)
        if min(len(o) for o in b)==ml: return np.stack(b)
        res = np.zeros((len(b), ml), dtype=b[0].dtype) + self.pad_idx
        for i,o in enumerate(b):
            if self.pre_pad: res[i, -len(o):] = o
            else:            res[i, :len(o)] = o
        return res

    def np_collate(self, batch):
        "Recursively collate a list of samples into (possibly nested) stacked numpy arrays."
        b = batch[0]
        if isinstance(b, (np.ndarray, np.generic)): return self.jag_stack(batch)
        elif isinstance(b, (int, float)): return np.array(batch)
        elif isinstance(b, string_classes): return batch
        elif isinstance(b, collections.abc.Mapping):
            return {key: self.np_collate([d[key] for d in batch]) for key in b}
        elif isinstance(b, collections.abc.Sequence):
            return [self.np_collate(samples) for samples in zip(*batch)]
        raise TypeError(f"batch must contain numbers, dicts or lists; found {type(b)}")

    def get_batch(self, indices):
        res = self.np_collate([self.dataset[i] for i in indices])
        if self.transpose:   res[0] = res[0].T
        if self.transpose_y: res[1] = res[1].T
        return res

    def __iter__(self):
        if self.num_workers==0:
            for batch in map(self.get_batch, iter(self.batch_sampler)):
                yield get_tensor(batch, self.pin_memory, self.half)
        else:
            with ThreadPoolExecutor(max_workers=self.num_workers) as e:
                # avoid py3.6 issue where queue is infinite and can result in memory exhaustion
                for c in chunk_iter(iter(self.batch_sampler), self.num_workers*10):
                    for batch in e.map(self.get_batch, c):
                        yield get_tensor(batch, self.pin_memory, self.half)
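
# Usage sketch (illustrative, not part of this module's API): assumes a dataset whose
# __getitem__ returns an (x, y) pair of numpy arrays, e.g. jagged token-id sequences.
# The dataset class and variable names below are hypothetical.
#
#   class SeqDataset:
#       def __init__(self, seqs, labels): self.seqs,self.labels = seqs,labels
#       def __len__(self): return len(self.seqs)
#       def __getitem__(self, i): return self.seqs[i], self.labels[i]
#
#   dl = DataLoader(SeqDataset(seqs, labels), batch_size=64, shuffle=True,
#                   pad_idx=1, pre_pad=True, num_workers=0)
#   for x, y in dl:   # x arrives left-padded with pad_idx and already moved to the GPU
#       ...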