tasks.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import json
  2. import logging
  3. import matplotlib
  4. matplotlib.use('Agg')
  5. import pandas as pd
  6. from pandas.api.types import is_string_dtype
  7. from pandas.api.types import is_numeric_dtype
  8. from sklearn.preprocessing import LabelEncoder
  9. import matplotlib.pyplot as plt
  10. from sklearn.preprocessing import MinMaxScaler
  11. from sklearn.model_selection import train_test_split
  12. import os
  13. import pyreadr
  14. from sklearn.linear_model import LogisticRegression
  15. import joblib
  16. from sklearn.metrics import roc_curve, auc
  17. import numpy as np
  18. from celery import shared_task
  19. from files.models import Files as Files_models
  20. from datetime import datetime, timedelta
  21. from django.utils import timezone
  22. import os
  23. from order.models import Order as Order_models
  24. from files.models import Files as Files_models
  25. from celery.utils.log import get_task_logger
  26. logger = get_task_logger(__name__)
  27. @shared_task
  28. def get_AUC(order_id):
  29. matplotlib.use('Agg')
  30. print(31, order_id)
  31. # logger.info(28282828, order_id)
  32. # 更新order状态
  33. try:
  34. order = Order_models.objects.get(id=order_id)
  35. order.status = 2
  36. order.save()
  37. # return JsonResponse(setSuccess(msg='更新成功'))
  38. except Order_models.DoesNotExist:
  39. logging.warning(json.dumps({
  40. 'msg': '订单不存在',
  41. 'order_id': order_id
  42. }))
  43. # return JsonResponse(setFailure(msg='订单不存在'))
  44. files = Files_models.objects.filter(order_id=order_id)
  45. logger.info(files[0].file_path)
  46. logger.info('-----')
  47. logger.info(files[1].file_path)
  48. logger.info('-----')
  49. logger.info(str(order_id))
  50. # body = json.loads(request.body)
  51. # logger.info(26, body)
  52. # 用户订单数据
  53. base_path = 'Rdata/' + str(order_id) + '/'
  54. logger.info(base_path)
  55. # return
  56. if not os.path.exists(base_path):
  57. os.makedirs(base_path)
  58. def ROC_curve(clf, X_test, y_test, name):
  59. logger.info('130130130130 ' + name + ' gene_list.npy 2')
  60. # ROC曲线绘制,AUC值计算
  61. # 计算每一类的得分
  62. yscore_rf = clf.predict_proba(X_test)
  63. logger.info('6666666666 ')
  64. fpr_rf, tpr_rf, thersholds = roc_curve(y_test, yscore_rf[:, 1])
  65. logger.info('797979797979 ')
  66. plt.figure()
  67. logger.info('8080808080 ')
  68. plt.plot(fpr_rf, tpr_rf, color='darkorange',
  69. lw=2, label='ROC curve (arebody = json.loads(request.body)a = %0.2f)' % auc(fpr_rf, tpr_rf))
  70. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  71. plt.xlim([0.0, 1.0])
  72. plt.ylim([0.0, 1.05])
  73. plt.xlabel('False Positive Rate')
  74. plt.ylabel('True Positive Rate')
  75. # plt.title('Receiver operating characteristic example')
  76. plt.title(name)
  77. plt.legend(loc="lower right")
  78. # plt.savefig(plt.savefig('figs/savefig_example.png'))
  79. plt.savefig(base_path + name + '.png')
  80. # plt.show()
  81. def a_curve(clf, X_test, y_test, name):
  82. plt.plot(X_test, y_test, color='darkorange',
  83. lw=2, label='ROC curve (arebody = json.loads(request.body)a = %0.2f)' % auc(fpr_rf, tpr_rf))
  84. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  85. plt.xlim([0.0, 1.0])
  86. plt.ylim([0.0, 1.05])
  87. plt.xlabel('False Positive Rate')
  88. plt.ylabel('True Positive Rate')
  89. # plt.title('Receiver operating characteristic example')
  90. plt.title(name)
  91. plt.legend(loc="lower right")
  92. # plt.savefig(plt.savefig('figs/savefig_example.png'))
  93. plt.savefig(base_path + name + '.png')
  94. logger.info('797979')
  95. # 读取非胃癌数据
  96. data_health = pyreadr.read_r(str(files[0].file_path)) # 'COAD.Rdata'
  97. data_health = data_health['data'] # .head(10000) # 提取前10000行,后续把.head()删掉
  98. logger.info('8383838383')
  99. # 读取胃癌数据
  100. data_sick = pyreadr.read_r(str(files[1].file_path)) # 'STAD.Rdata'
  101. data_sick = data_sick['data'] # .head(10000) # 提取前10000行,后续把.head()删掉
  102. logger.info('1111')
  103. logger.info('8888888')
  104. # 数据处理
  105. data_health['D'] = 0
  106. data_sick['D'] = 1
  107. data = pd.concat([data_health, data_sick])
  108. # data_sick[['Variant_Classification','Tumor_Sample_Barcode']]
  109. sample_list = list(data.groupby(['Tumor_Sample_Barcode']).groups.keys()) # 把获取类别转为列表
  110. gene_list = list(data.groupby(['Variant_Classification']).groups.keys()) # 把获取类别转为列表
  111. np.save(base_path + 'gene_list.npy', np.array(gene_list)) # 保存为.npy格式
  112. np.save(base_path + 'sample_list.npy', np.array(sample_list)) # 保存为.npy格式
  113. logger.info('2222')
  114. # 计算基因数目
  115. list_final = gene_list + ['Tumor_Sample_Barcode', 'D']
  116. data_final = pd.DataFrame(columns=list_final)
  117. ind = 0
  118. for sample in sample_list:
  119. tempdata = data[data['Tumor_Sample_Barcode'] == sample]
  120. message = []
  121. for gene in gene_list:
  122. num = len(tempdata[tempdata['Variant_Classification'] == gene])
  123. message.append(num)
  124. message = message + [sample, tempdata.iloc[0][-1]]
  125. data_final.loc[ind] = message
  126. ind = ind + 1
  127. # data_final
  128. logger.info('113113113113113')
  129. # -------------------------------------SC_GRS方法-----------------------------------------
  130. data_final['Gi_sum'] = data_final[gene_list].sum(axis=1)
  131. # 模型训练
  132. X_train, X_test, y_train, y_test = train_test_split(data_final['Gi_sum'].values.reshape(-1, 1),
  133. data_final['D'].values.astype('int'), test_size=0.2)
  134. clf = LogisticRegression(penalty='l2')
  135. clf.fit(X_train, y_train)
  136. # 模型系数
  137. coef = clf.coef_
  138. intercept = clf.intercept_
  139. # 储存模型
  140. joblib.dump(clf, base_path + 'clf_scgrs.model')
  141. logger.info('3333')
  142. # ROC绘图 SC_GRS方法
  143. ROC_curve(clf, X_test, y_test, name='SC_GRS')
  144. logger.info('130130130130 '+ base_path + 'gene_list.npy')
  145. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  146. # 模型加载
  147. gene_list = np.load(base_path + 'gene_list.npy')
  148. logger.info('134134134134134134134')
  149. gene_list = gene_list.tolist()
  150. clf_scgrs = joblib.load(base_path + 'clf_scgrs.model')
  151. logger.info('137137137137137')
  152. # 数据处理
  153. X_test = data_final[gene_list].sum(axis=1).values.reshape(-1, 1)
  154. y_pre = clf.predict(X_test) # 分类结果预测
  155. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  156. logger.info('4444')
  157. # logger.info(95, prob_pre)
  158. # -------------------------------------DL_GRS方法-----------------------------------------
  159. # 模型训练
  160. X_train, X_test, y_train, y_test = train_test_split(data_final[gene_list], data_final['D'].values.astype('int'),
  161. test_size=0.2)
  162. clf = LogisticRegression(penalty='l2', max_iter=1000)
  163. clf.fit(X_train, y_train)
  164. # 模型系数
  165. coef = clf.coef_
  166. intercept = clf.intercept_
  167. # 储存模型
  168. joblib.dump(clf, base_path + 'clf_dlgrs.model')
  169. # ROC曲线绘制,AUC值计算
  170. ROC_curve(clf, X_test, y_test, name='DL_GRS')
  171. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  172. # 模型加载
  173. gene_list = np.load(base_path + 'gene_list.npy')
  174. gene_list = gene_list.tolist()
  175. clf_scgrs = joblib.load(base_path + 'clf_dlgrs.model')
  176. # 数据处理
  177. X_test = data_final[gene_list]
  178. y_pre = clf.predict(X_test) # 分类结果预测
  179. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  180. # logger.info(119, prob_pre)
  181. # -------------------------------------OR_GRS方法-----------------------------------------
  182. # or值计算
  183. tempdata = data_final[data_final['D'] == 0]
  184. gene_mean0 = tempdata[gene_list].mean()
  185. tempdata = data_final[data_final['D'] == 1]
  186. gene_mean1 = tempdata[gene_list].mean()
  187. omega_or = np.log(gene_mean1 / gene_mean0)
  188. # 模型训练
  189. data_final['GRS'] = (data_final[gene_list] * omega_or).sum(axis=1)
  190. X_train, X_test, y_train, y_test = train_test_split(data_final['GRS'].values.reshape(-1, 1),
  191. data_final['D'].values.astype('int'), test_size=0.2)
  192. clf = LogisticRegression(penalty='l2')
  193. clf.fit(X_train, y_train)
  194. # 模型系数
  195. coef = clf.coef_
  196. intercept = clf.intercept_
  197. # 储存模型
  198. joblib.dump(clf, base_path + 'clf_orgrs.model')
  199. np.save(base_path + 'omega_or.npy', np.array(omega_or)) # 保存为.npy格式
  200. # ROC曲线绘制,AUC值计算 OR_GRS方法
  201. ROC_curve(clf, X_test, y_test, name='OR_GRS')
  202. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  203. # 模型加载
  204. gene_list = np.load(base_path + 'gene_list.npy')
  205. gene_list = gene_list.tolist()
  206. clf_scgrs = joblib.load(base_path + 'clf_dlgrs.model')
  207. omega_or = np.load(base_path + 'omega_or.npy')
  208. # 数据处理
  209. X_test = (data_final[gene_list] * omega_or).sum(axis=1).values.reshape(-1, 1)
  210. y_pre = clf.predict(X_test) # 分类结果预测
  211. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  212. # logger.info(154, prob_pre)
  213. # 更新order状态
  214. try:
  215. order = Order_models.objects.get(id=order_id)
  216. order.status = 3
  217. order.save()
  218. # return JsonResponse(setSuccess(msg='更新成功'))
  219. except Order_models.DoesNotExist:
  220. logging.warning(json.dumps({
  221. 'msg': '订单不存在',
  222. 'order_id': order_id
  223. }))