# column_data.py

from .imports import *
from .torch_imports import *
from .dataset import *
from .learner import *
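
# Columnar (structured/tabular) data support: datasets, ModelData wrappers, a
# mixed categorical/continuous feed-forward model, and collaborative filtering.


# Dataset for data already split into per-column arrays: each item is the
# idx-th element of every x column followed by the target. The is_reg/is_multi
# flags travel with the data so the learner can pick a matching loss.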
class PassthruDataset(Dataset):
    def __init__(self, *args, is_reg=True, is_multi=False):
        *xs,y = args
        self.xs,self.y = xs,y
        self.is_reg = is_reg
        self.is_multi = is_multi

    def __len__(self): return len(self.y)
    def __getitem__(self, idx): return [o[idx] for o in self.xs] + [self.y[idx]]

    @classmethod
    def from_data_frame(cls, df, cols_x, col_y, is_reg=True, is_multi=False):
        cols = [df[o] for o in cols_x+[col_y]]
        return cls(*cols, is_reg=is_reg, is_multi=is_multi)
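
# Dataset that stacks the categorical columns (as int64 codes) and continuous
# columns (as float32) into two arrays. A dummy zero column stands in when
# either group is empty, and a dummy zero target when y is None (e.g. test
# sets); regression targets get an extra axis to match the model output.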
class ColumnarDataset(Dataset):
    def __init__(self, cats, conts, y, is_reg, is_multi):
        n = len(cats[0]) if cats else len(conts[0])
        self.cats  = np.stack(cats,  1).astype(np.int64)   if cats  else np.zeros((n,1))
        self.conts = np.stack(conts, 1).astype(np.float32) if conts else np.zeros((n,1))
        self.y = np.zeros((n,1)) if y is None else y
        if is_reg:
            self.y = self.y[:,None]
        self.is_reg = is_reg
        self.is_multi = is_multi

    def __len__(self): return len(self.y)

    def __getitem__(self, idx):
        return [self.cats[idx], self.conts[idx], self.y[idx]]

    @classmethod
    def from_data_frames(cls, df_cat, df_cont, y=None, is_reg=True, is_multi=False):
        cat_cols = [c.values for n,c in df_cat.items()]
        cont_cols = [c.values for n,c in df_cont.items()]
        return cls(cat_cols, cont_cols, y, is_reg, is_multi)

    @classmethod
    def from_data_frame(cls, df, cat_flds, y=None, is_reg=True, is_multi=False):
        return cls.from_data_frames(df[cat_flds], df.drop(cat_flds, axis=1), y, is_reg, is_multi)
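
# ModelData wrapper that builds the train/val (and optional test) DataLoaders
# from the datasets above; the validation loader uses bs*2 since no gradients
# are kept. A usage sketch, where PATH, val_idxs, df, y, cat_vars and emb_szs
# (a list of (cardinality, embedding_dim) pairs) are hypothetical user inputs:
#
#   md = ColumnarModelData.from_data_frame(PATH, val_idxs, df, y,
#                                          cat_flds=cat_vars, bs=128)
#   learn = md.get_learner(emb_szs, n_cont=len(df.columns)-len(cat_vars),
#                          emb_drop=0.04, out_sz=1, szs=[1000,500],
#                          drops=[0.001,0.01])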
class ColumnarModelData(ModelData):
    def __init__(self, path, trn_ds, val_ds, bs, test_ds=None, shuffle=True):
        test_dl = DataLoader(test_ds, bs, shuffle=False, num_workers=1) if test_ds is not None else None
        super().__init__(path, DataLoader(trn_ds, bs, shuffle=shuffle, num_workers=1),
                         DataLoader(val_ds, bs*2, shuffle=False, num_workers=1), test_dl)

    @classmethod
    def from_arrays(cls, path, val_idxs, xs, y, is_reg=True, is_multi=False, bs=64, test_xs=None, shuffle=True):
        ((val_xs, trn_xs), (val_y, trn_y)) = split_by_idx(val_idxs, xs, y)
        test_ds = PassthruDataset(*(test_xs.T), [0] * len(test_xs), is_reg=is_reg, is_multi=is_multi) if test_xs is not None else None
        return cls(path, PassthruDataset(*(trn_xs.T), trn_y, is_reg=is_reg, is_multi=is_multi),
                   PassthruDataset(*(val_xs.T), val_y, is_reg=is_reg, is_multi=is_multi),
                   bs=bs, shuffle=shuffle, test_ds=test_ds)

    @classmethod
    def from_data_frames(cls, path, trn_df, val_df, trn_y, val_y, cat_flds, bs, is_reg, is_multi, test_df=None, shuffle=True):
        trn_ds  = ColumnarDataset.from_data_frame(trn_df,  cat_flds, trn_y, is_reg, is_multi)
        val_ds  = ColumnarDataset.from_data_frame(val_df,  cat_flds, val_y, is_reg, is_multi)
        test_ds = ColumnarDataset.from_data_frame(test_df, cat_flds, None,  is_reg, is_multi) if test_df is not None else None
        return cls(path, trn_ds, val_ds, bs, test_ds=test_ds, shuffle=shuffle)

    @classmethod
    def from_data_frame(cls, path, val_idxs, df, y, cat_flds, bs, is_reg=True, is_multi=False, test_df=None, shuffle=True):
        ((val_df, trn_df), (val_y, trn_y)) = split_by_idx(val_idxs, df, y)
        return cls.from_data_frames(path, trn_df, val_df, trn_y, val_y, cat_flds, bs, is_reg, is_multi, test_df=test_df, shuffle=shuffle)

    def get_learner(self, emb_szs, n_cont, emb_drop, out_sz, szs, drops,
                    y_range=None, use_bn=False, **kwargs):
        model = MixedInputModel(emb_szs, n_cont, emb_drop, out_sz, szs, drops, y_range, use_bn, self.is_reg, self.is_multi)
        return StructuredLearner(self, StructuredModel(to_gpu(model)), opt_fn=optim.Adam, **kwargs)
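
# Initialise an embedding's weights uniformly in [-sc, sc], where
# sc = 2/(embedding_dim+1); small symmetric values keep early training stable.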
def emb_init(x):
    x = x.weight.data
    sc = 2/(x.size(1)+1)
    x.uniform_(-sc,sc)
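
# Feed-forward net over mixed inputs: one embedding per categorical field,
# batchnorm over the continuous fields, then fully-connected layers with ReLU,
# optional batchnorm and dropout, finishing with a task-dependent output head.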
class MixedInputModel(nn.Module):
    def __init__(self, emb_szs, n_cont, emb_drop, out_sz, szs, drops,
                 y_range=None, use_bn=False, is_reg=True, is_multi=False):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(c, s) for c,s in emb_szs])
        for emb in self.embs: emb_init(emb)
        n_emb = sum(e.embedding_dim for e in self.embs)
        self.n_emb,self.n_cont = n_emb,n_cont
        szs = [n_emb+n_cont] + szs
        self.lins = nn.ModuleList([
            nn.Linear(szs[i], szs[i+1]) for i in range(len(szs)-1)])
        self.bns = nn.ModuleList([
            nn.BatchNorm1d(sz) for sz in szs[1:]])
        for o in self.lins: kaiming_normal(o.weight.data)
        self.outp = nn.Linear(szs[-1], out_sz)
        kaiming_normal(self.outp.weight.data)
        self.emb_drop = nn.Dropout(emb_drop)
        self.drops = nn.ModuleList([nn.Dropout(drop) for drop in drops])
        self.bn = nn.BatchNorm1d(n_cont)
        self.use_bn,self.y_range = use_bn,y_range
        self.is_reg = is_reg
        self.is_multi = is_multi
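
    # Embed and concatenate the categorical codes, append the batchnormed
    # continuous block, run the linear stack, then squash the output: sigmoid
    # for multi-label, log-softmax for single-label classification, or a
    # scaled sigmoid when y_range bounds a regression target.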
    def forward(self, x_cat, x_cont):
        if self.n_emb != 0:
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embs)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            x2 = self.bn(x_cont)
            x = torch.cat([x, x2], 1) if self.n_emb != 0 else x2
        for l,d,b in zip(self.lins, self.drops, self.bns):
            x = F.relu(l(x))
            if self.use_bn: x = b(x)
            x = d(x)
        x = self.outp(x)
        if not self.is_reg:
            if self.is_multi:
                x = F.sigmoid(x)
            else:
                x = F.log_softmax(x)
        elif self.y_range:
            x = F.sigmoid(x)
            x = x*(self.y_range[1] - self.y_range[0])
            x = x+self.y_range[0]
        return x
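
# Learner whose loss follows the task flags: MSE for regression, binary
# cross-entropy for multi-label, negative log-likelihood otherwise.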
class StructuredLearner(Learner):
    def __init__(self, data, models, **kwargs):
        super().__init__(data, models, **kwargs)

    def _get_crit(self, data): return F.mse_loss if data.is_reg else F.binary_cross_entropy if data.is_multi else F.nll_loss

    def summary(self):
        x = [torch.ones(3, self.data.trn_ds.cats.shape[1], dtype=torch.int64),
             torch.rand(3, self.data.trn_ds.conts.shape[1])]
        return model_summary(self.model, x)
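
# Layer groups used for discriminative (per-group) learning rates:
# embeddings, then the linear/batchnorm stack, then the output layer.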
class StructuredModel(BasicModel):
    def get_layer_groups(self):
        m = self.model
        return [m.embs, children(m.lins)+children(m.bns), m.outp]
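
# Dataset for collaborative filtering on (user, item, rating) triples; raw ids
# are mapped to contiguous indices so they can feed embedding layers.
# A usage sketch (the file and column names are hypothetical):
#
#   cf = CollabFilterDataset.from_csv(PATH, 'ratings.csv', 'userId', 'movieId', 'rating')
#   learn = cf.get_learner(n_factors=50, val_idxs=val_idxs, bs=64)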
class CollabFilterDataset(Dataset):
    def __init__(self, path, user_col, item_col, ratings):
        self.ratings,self.path = ratings.values.astype(np.float32),path
        self.n = len(ratings)
        (self.users,self.user2idx,self.user_col,self.n_users) = self.proc_col(user_col)
        (self.items,self.item2idx,self.item_col,self.n_items) = self.proc_col(item_col)
        self.min_score,self.max_score = min(ratings),max(ratings)
        self.cols = [self.user_col,self.item_col,self.ratings]

    @classmethod
    def from_data_frame(cls, path, df, user_name, item_name, rating_name):
        return cls(path, df[user_name], df[item_name], df[rating_name])

    @classmethod
    def from_csv(cls, path, csv, user_name, item_name, rating_name):
        df = pd.read_csv(os.path.join(path,csv))
        return cls.from_data_frame(path, df, user_name, item_name, rating_name)

    def proc_col(self,col):
        uniq = col.unique()
        name2idx = {o:i for i,o in enumerate(uniq)}
        return (uniq, name2idx, np.array([name2idx[x] for x in col]), len(uniq))

    def __len__(self): return self.n
    def __getitem__(self, idx): return [o[idx] for o in self.cols]

    def get_data(self, val_idxs, bs):
        val, trn = zip(*split_by_idx(val_idxs, *self.cols))
        return ColumnarModelData(self.path, PassthruDataset(*trn), PassthruDataset(*val), bs)

    def get_model(self, n_factors):
        model = EmbeddingDotBias(n_factors, self.n_users, self.n_items, self.min_score, self.max_score)
        return CollabFilterModel(to_gpu(model))

    def get_learner(self, n_factors, val_idxs, bs, **kwargs):
        return CollabFilterLearner(self.get_data(val_idxs, bs), self.get_model(n_factors), **kwargs)
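
# Create an embedding initialised with small uniform weights in [-0.05, 0.05].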
def get_emb(ni,nf):
    e = nn.Embedding(ni, nf)
    e.weight.data.uniform_(-0.05,0.05)
    return e
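
# Matrix-factorisation model: the prediction is the dot product of user and
# item factor vectors plus per-user and per-item biases, squashed by a sigmoid
# and rescaled into [min_score, max_score].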
class EmbeddingDotBias(nn.Module):
    def __init__(self, n_factors, n_users, n_items, min_score, max_score):
        super().__init__()
        self.min_score,self.max_score = min_score,max_score
        (self.u, self.i, self.ub, self.ib) = [get_emb(*o) for o in [
            (n_users, n_factors), (n_items, n_factors), (n_users,1), (n_items,1)
        ]]

    def forward(self, users, items):
        um = self.u(users) * self.i(items)
        res = um.sum(1) + self.ub(users).squeeze() + self.ib(items).squeeze()
        return F.sigmoid(res) * (self.max_score-self.min_score) + self.min_score
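
# Collaborative-filtering learner: always trained with MSE loss.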
class CollabFilterLearner(Learner):
    def __init__(self, data, models, **kwargs):
        super().__init__(data, models, **kwargs)

    def _get_crit(self, data): return F.mse_loss

    def summary(self): return model_summary(self.model, [torch.ones(3, dtype=torch.int64), torch.ones(3, dtype=torch.int64)])
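
# Single layer group: the whole model is trained with one learning rate.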
class CollabFilterModel(BasicModel):
    def get_layer_groups(self): return self.model