text.py

from .core import *
from .learner import *
from .lm_rnn import *
from torch.utils.data.sampler import Sampler
import spacy
from spacy.symbols import ORTH

re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')
def tokenize(s): return re_tok.sub(r' \1 ', s).split()

def texts_labels_from_folders(path, folders):
    texts,labels = [],[]
    for idx,label in enumerate(folders):
        for fname in glob(os.path.join(path, label, '*.*')):
            texts.append(open(fname, 'r').read())
            labels.append(idx)
    return texts, np.array(labels).astype(np.int64)

def numericalize_tok(tokens, max_vocab=50000, min_freq=0, unk_tok="_unk_", pad_tok="_pad_", bos_tok="_bos_", eos_tok="_eos_"):
    """Takes in text tokens and returns int2tok and tok2int converters

    Arguments:
        tokens(list): List of tokens. Can be a list of strings, or a list of lists of strings.
        max_vocab(int): Number of tokens to return in the vocab (sorted by frequency)
        min_freq(int): Minimum number of instances a token must be present in order to be preserved.
        unk_tok(str): Token to use when unknown tokens are encountered in the source text.
        pad_tok(str): Token to use when padding sequences.
        bos_tok(str): Token to mark the beginning of a sequence.
        eos_tok(str): Token to mark the end of a sequence.
    """
    if isinstance(tokens, str):
        raise ValueError("Expected to receive a list of tokens. Received a string instead")
    if isinstance(tokens[0], list):
        tokens = [p for o in tokens for p in o]
    freq = Counter(tokens)
    int2tok = [o for o,c in freq.most_common(max_vocab) if c>min_freq]
    unk_id = 3
    int2tok.insert(0, bos_tok)
    int2tok.insert(1, pad_tok)
    int2tok.insert(2, eos_tok)
    int2tok.insert(unk_id, unk_tok)
    tok2int = collections.defaultdict(lambda:unk_id, {v:k for k,v in enumerate(int2tok)})
    return int2tok, tok2int
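
# Example (illustrative only, not part of the original module): building a small
# vocab with numericalize_tok. The token lists below are made up for demonstration.
#
#   toks = [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]
#   int2tok, tok2int = numericalize_tok(toks, max_vocab=10)
#   # int2tok[:4] == ['_bos_', '_pad_', '_eos_', '_unk_']; frequent tokens follow.
#   ids = [tok2int[t] for t in toks[0]]   # unseen tokens map to the _unk_ id (3)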
class Tokenizer():
    def __init__(self, lang='en'):
        self.re_br = re.compile(r'<\s*br\s*/?>', re.IGNORECASE)
        self.tok = spacy.load(lang)
        for w in ('<eos>','<bos>','<unk>'):
            self.tok.tokenizer.add_special_case(w, [{ORTH: w}])

    def sub_br(self,x): return self.re_br.sub("\n", x)

    def spacy_tok(self,x):
        return [t.text for t in self.tok.tokenizer(self.sub_br(x))]

    # patterns for character repeats ("soooo") and repeated words ("very very very very")
    re_rep = re.compile(r'(\S)(\1{3,})')
    re_word_rep = re.compile(r'(\b\w+\W+)(\1{3,})')

    @staticmethod
    def replace_rep(m):
        TK_REP = 'tk_rep'
        c,cc = m.groups()
        return f' {TK_REP} {len(cc)+1} {c} '

    @staticmethod
    def replace_wrep(m):
        TK_WREP = 'tk_wrep'
        c,cc = m.groups()
        return f' {TK_WREP} {len(cc.split())+1} {c} '

    @staticmethod
    def do_caps(ss):
        # mark all-caps words with a ' t_up ' token and lower-case the text
        TOK_UP,TOK_SENT,TOK_MIX = ' t_up ',' t_st ',' t_mx '
        res = []
        prev='.'
        re_word = re.compile(r'\w')
        re_nonsp = re.compile(r'\S')
        for s in re.findall(r'\w+|\W+', ss):
            res += ([TOK_UP,s.lower()] if (s.isupper() and (len(s)>2))
#                     else [TOK_SENT,s.lower()] if (s.istitle() and re_word.search(prev))
                    else [s.lower()])
#             if re_nonsp.search(s): prev = s
        return ''.join(res)

    def proc_text(self, s):
        s = self.re_rep.sub(Tokenizer.replace_rep, s)
        s = self.re_word_rep.sub(Tokenizer.replace_wrep, s)
        s = Tokenizer.do_caps(s)
        s = re.sub(r'([/#])', r' \1 ', s)
        s = re.sub(' {2,}', ' ', s)
        return self.spacy_tok(s)

    @staticmethod
    def proc_all(ss, lang):
        tok = Tokenizer(lang)
        return [tok.proc_text(s) for s in ss]

    @staticmethod
    def proc_all_mp(ss, lang='en'):
        ncpus = num_cpus()//2
        with ProcessPoolExecutor(ncpus) as e:
            return sum(e.map(Tokenizer.proc_all, ss, [lang]*len(ss)), [])
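
# Example (illustrative only): tokenizing raw strings with the Tokenizer above.
# Assumes a spaCy English model is installed so that spacy.load('en') succeeds.
#
#   texts = ['I LOVED this movie!!!!', '<br />So so so boring...']
#   toks = Tokenizer().proc_text(texts[0])
#   # caps and character repeats get marked, e.g. ' t_up ' and ' tk_rep ' tokens
#   all_toks = Tokenizer.proc_all_mp([texts])   # expects a list of chunks (list of lists)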
class TextDataset(Dataset):
    def __init__(self, x, y, backwards=False, sos=None, eos=None):
        self.x,self.y,self.backwards,self.sos,self.eos = x,y,backwards,sos,eos

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.backwards: x = list(reversed(x))
        if self.eos is not None: x = x + [self.eos]
        if self.sos is not None: x = [self.sos]+x
        return np.array(x),self.y[idx]

    def __len__(self): return len(self.x)


class SortSampler(Sampler):
    def __init__(self, data_source, key): self.data_source,self.key = data_source,key
    def __len__(self): return len(self.data_source)
    def __iter__(self):
        return iter(sorted(range(len(self.data_source)), key=self.key, reverse=True))
class SortishSampler(Sampler):
    """Returns an iterator that traverses the data in randomly ordered batches that are approximately the same size.
    The batch with the largest key is always returned first because of pytorch cuda memory allocation sequencing:
    without it, multiple buffers may be allocated when the first buffer created isn't large enough
    to hold the next batch in the sequence.
    """
    def __init__(self, data_source, key, bs):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self): return len(self.data_source)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data_source))
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]     # then make sure it goes first.
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:]))
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)
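
# Example (illustrative sketch, not part of the original file): wiring TextDataset
# together with the samplers above for a classifier. `trn_ids`/`trn_y`/`val_ids` and
# `bs` are assumed to be numericalized token id lists, labels and a batch size; the
# DataLoader shown is assumed to be fastai 0.7's own padding DataLoader, not torch's.
#
#   trn_ds   = TextDataset(trn_ids, trn_y)
#   trn_samp = SortishSampler(trn_ids, key=lambda i: len(trn_ids[i]), bs=bs)
#   val_samp = SortSampler(val_ids, key=lambda i: len(val_ids[i]))
#   trn_dl   = DataLoader(trn_ds, bs, transpose=True, pad_idx=1, sampler=trn_samp)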
class LanguageModelLoader():
    """Returns a language model iterator that iterates through batches whose sequence length is drawn from N(bptt, 5).
    The first batch returned is always bptt+25, the max possible width. This is done because of the way that pytorch
    allocates cuda memory, in order to prevent multiple buffers from being created as the batch width grows.
    """
    def __init__(self, nums, bs, bptt, backwards=False):
        self.bs,self.bptt,self.backwards = bs,bptt,backwards
        self.data = self.batchify(nums)
        self.i,self.iter = 0,0
        self.n = len(self.data)

    def __iter__(self):
        self.i,self.iter = 0,0
        while self.i < self.n-1 and self.iter<len(self):
            if self.i == 0:
                seq_len = self.bptt + 5 * 5
            else:
                bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
                seq_len = max(5, int(np.random.normal(bptt, 5)))
            res = self.get_batch(self.i, seq_len)
            self.i += seq_len
            self.iter += 1
            yield res

    def __len__(self): return self.n // self.bptt - 1

    def batchify(self, data):
        nb = data.shape[0] // self.bs
        data = np.array(data[:nb*self.bs])
        data = data.reshape(self.bs, -1).T
        if self.backwards: data=data[::-1]
        return T(data)

    def get_batch(self, i, seq_len):
        source = self.data
        seq_len = min(seq_len, len(source) - 1 - i)
        return source[i:i+seq_len], source[i+1:i+1+seq_len].view(-1)
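
# Example (illustrative only): feeding a flat array of token ids to the loader.
# `trn_lm` is assumed to be a 1-D numpy array of the concatenated corpus.
#
#   bs, bptt = 64, 70
#   trn_dl = LanguageModelLoader(trn_lm, bs, bptt)
#   x, y = next(iter(trn_dl))   # x: (seq_len, bs) token matrix, y: flattened next-token targets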
class LanguageModel(BasicModel):
    def get_layer_groups(self):
        m = self.model[0]
        return [*zip(m.rnns, m.dropouths), (self.model[1], m.dropouti)]
class LanguageModelData():
    def __init__(self, path, pad_idx, n_tok, trn_dl, val_dl, test_dl=None, **kwargs):
        self.path,self.pad_idx,self.n_tok = path,pad_idx,n_tok
        self.trn_dl,self.val_dl,self.test_dl = trn_dl,val_dl,test_dl

    def get_model(self, opt_fn, emb_sz, n_hid, n_layers, **kwargs):
        m = get_language_model(self.n_tok, emb_sz, n_hid, n_layers, self.pad_idx, **kwargs)
        model = LanguageModel(to_gpu(m))
        return RNN_Learner(self, model, opt_fn=opt_fn)
class RNN_Learner(Learner):
    def __init__(self, data, models, **kwargs):
        super().__init__(data, models, **kwargs)

    def _get_crit(self, data): return F.cross_entropy
    def fit(self, *args, **kwargs): return super().fit(*args, **kwargs, seq_first=True)
    def save_encoder(self, name): save_model(self.model[0], self.get_model_path(name))
    def load_encoder(self, name): load_model(self.model[0], self.get_model_path(name))
class TextModel(BasicModel):
    def get_layer_groups(self):
        m = self.model[0]
        return [(m.encoder, m.dropouti), *zip(m.rnns, m.dropouths), (self.model[1])]
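
# Example (illustrative sketch of the language-model workflow, not part of the
# original file): names like `trn_lm`/`val_lm` (1-D token id arrays), `vs` (vocab
# size), `PATH` and the hyperparameter values are assumptions for demonstration.
#
#   trn_dl = LanguageModelLoader(trn_lm, bs=64, bptt=70)
#   val_dl = LanguageModelLoader(val_lm, bs=64, bptt=70)
#   md = LanguageModelData(PATH, pad_idx=1, n_tok=vs, trn_dl=trn_dl, val_dl=val_dl)
#   learner = md.get_model(partial(optim.Adam, betas=(0.8, 0.99)), emb_sz=400,
#                          n_hid=1150, n_layers=3, dropouti=0.05, dropout=0.05,
#                          wdrop=0.1, dropoute=0.02, dropouth=0.05)
#   learner.fit(1e-3, 1, wds=1e-7)
#   learner.save_encoder('lm_enc')   # reuse the encoder weights for a classifier later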