tasks.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import json
  2. import logging
  3. import time
  4. import matplotlib
  5. matplotlib.use('Agg')
  6. import pandas as pd
  7. from pandas.api.types import is_string_dtype
  8. from pandas.api.types import is_numeric_dtype
  9. from sklearn.preprocessing import LabelEncoder
  10. import matplotlib.pyplot as plt
  11. from sklearn.preprocessing import MinMaxScaler
  12. from sklearn.model_selection import train_test_split
  13. import os
  14. import pyreadr
  15. from sklearn.linear_model import LogisticRegression
  16. import joblib
  17. from sklearn.metrics import roc_curve, auc
  18. import numpy as np
  19. from celery import shared_task
  20. from files.models import Files as Files_models
  21. from datetime import datetime, timedelta
  22. from django.utils import timezone
  23. import os
  24. from order.models import Order as Order_models
  25. from files.models import Files as Files_models
  26. from celery.utils.log import get_task_logger
  27. logger = get_task_logger(__name__)
  28. @shared_task
  29. def get_AUC(order_id):
  30. matplotlib.use('Agg')
  31. print(31, order_id)
  32. # logger.info(28282828, order_id)
  33. # 更新order状态
  34. try:
  35. order = Order_models.objects.get(id=order_id)
  36. order.status = 2
  37. order.save()
  38. # return JsonResponse(setSuccess(msg='更新成功'))
  39. except Order_models.DoesNotExist:
  40. logging.warning(json.dumps({
  41. 'msg': '订单不存在',
  42. 'order_id': order_id
  43. }))
  44. # return JsonResponse(setFailure(msg='订单不存在'))
  45. files = Files_models.objects.filter(order_id=order_id)
  46. logger.info(files[0].file_path)
  47. logger.info('-----')
  48. logger.info(files[1].file_path)
  49. logger.info('-----')
  50. logger.info(str(order_id))
  51. # body = json.loads(request.body)
  52. # logger.info(26, body)
  53. # 用户订单数据
  54. base_path = 'Rdata/' + str(order_id) + '/'
  55. logger.info(base_path)
  56. # return
  57. if not os.path.exists(base_path):
  58. os.makedirs(base_path)
  59. def ROC_curve(clf, X_test, y_test, name):
  60. logger.info('130130130130 ' + name + ' gene_list.npy 2')
  61. # ROC曲线绘制,AUC值计算
  62. # 计算每一类的得分
  63. yscore_rf = clf.predict_proba(X_test)
  64. logger.info('6666666666 ')
  65. fpr_rf, tpr_rf, thersholds = roc_curve(y_test, yscore_rf[:, 1])
  66. logger.info('797979797979 ')
  67. plt.figure()
  68. logger.info('8080808080 ')
  69. plt.plot(fpr_rf, tpr_rf, color='darkorange',
  70. lw=2, label='ROC curve (arebody = json.loads(request.body)a = %0.2f)' % auc(fpr_rf, tpr_rf))
  71. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  72. plt.xlim([0.0, 1.0])
  73. plt.ylim([0.0, 1.05])
  74. plt.xlabel('False Positive Rate')
  75. plt.ylabel('True Positive Rate')
  76. # plt.title('Receiver operating characteristic example')
  77. plt.title(name)
  78. plt.legend(loc="lower right")
  79. # plt.savefig(plt.savefig('figs/savefig_example.png'))
  80. plt.savefig(base_path + name + '.png')
  81. # plt.show()
  82. logger.info('797979')
  83. # 读取非胃癌数据
  84. data_health = pyreadr.read_r(str(files[0].file_path)) # 'COAD.Rdata'
  85. data_health = data_health['data'] # .head(10000) # 提取前10000行,后续把.head()删掉
  86. logger.info('8383838383')
  87. # 读取胃癌数据
  88. data_sick = pyreadr.read_r(str(files[1].file_path)) # 'STAD.Rdata'
  89. data_sick = data_sick['data'] # .head(10000) # 提取前10000行,后续把.head()删掉
  90. logger.info('1111')
  91. logger.info('8888888')
  92. # 数据处理
  93. data_health['D'] = 0
  94. data_sick['D'] = 1
  95. data = pd.concat([data_health, data_sick])
  96. # data_sick[['Variant_Classification','Tumor_Sample_Barcode']]
  97. sample_list = list(data.groupby(['Tumor_Sample_Barcode']).groups.keys()) # 把获取类别转为列表
  98. gene_list = list(data.groupby(['Variant_Classification']).groups.keys()) # 把获取类别转为列表
  99. np.save(base_path + 'gene_list.npy', np.array(gene_list)) # 保存为.npy格式
  100. np.save(base_path + 'sample_list.npy', np.array(sample_list)) # 保存为.npy格式
  101. logger.info('2222')
  102. # 计算基因数目
  103. list_final = gene_list + ['Tumor_Sample_Barcode', 'D']
  104. data_final = pd.DataFrame(columns=list_final)
  105. def save_Variant_Type():
  106. Variant_Type = data['Variant_Type'].value_counts()
  107. # 创建一个条形图
  108. plt.bar(Variant_Type.index, Variant_Type.values)
  109. # 添加标题和轴标签
  110. plt.title('Variant Type Counts')
  111. plt.xlabel('Variant Type')
  112. plt.ylabel('Count')
  113. # 保存图形为 PNG 文件
  114. plt.savefig(base_path + 'Variant_Type' + '.png')
  115. def save_Variant_Type2(name, max_value=100):
  116. Variant_Type2 = data[name].value_counts()
  117. # 创建一个条形图
  118. plt.plot(Variant_Type2.values, Variant_Type2.index)
  119. plt.ylim([0.0, max_value])
  120. plt.xlim([0.0, len(Variant_Type2.index)])
  121. # 添加标题和轴标签
  122. plt.title(name)
  123. plt.xlabel(name)
  124. plt.ylabel('Count')
  125. # 保存图形为 PNG 文件
  126. plt.savefig(base_path + name + '.png')
  127. def save_end_data(self_prob_pre, name, pro_name): # 输出大于0.5的数据
  128. second_column = self_prob_pre[:, 1]
  129. # 找到大于0.5的数据的下标
  130. indices = np.where(second_column >= 0.5)[0]
  131. # print(144, pro_name, indices)
  132. with open(f'{base_path}{name}.txt', 'w') as f:
  133. for item in indices:
  134. f.write("%s\n" % item)
  135. #
  136. # # 绘制折线图
  137. # # plt.plot(second_column)
  138. # plt.plot(indices, second_column[indices], 'ro')
  139. # # 设置 x 轴和 y 轴标签
  140. # plt.title(f"{pro_name} Probability prediction greater than 0.5")
  141. # plt.xlabel("Index")
  142. # plt.ylabel("Value")
  143. # # 保存图像
  144. # plt.savefig(f"{base_path}{name}概率预测.png", format="png")
  145. # # 清除图像
  146. # plt.clf()
  147. # time.sleep(1)
  148. # save_Variant_Type()
  149. # time.sleep(1)
  150. # save_Variant_Type2('t_ref_count', max_value=1000)
  151. # time.sleep(1)
  152. # save_Variant_Type2('t_alt_count', max_value=1000)
  153. # time.sleep(1)
  154. # save_Variant_Type2('t_depth', max_value=1000)
  155. # time.sleep(1)
  156. # 绘制图像
  157. # plt.imshow(data['t_depth'])
  158. # 保存为图片文件
  159. # plt.savefig(base_path + "t_depth.png", format='png')
  160. # t_depth
  161. #
  162. # t_ref_count
  163. #
  164. # t_alt_count
  165. #
  166. # n_depth
  167. ind = 0
  168. for sample in sample_list:
  169. tempdata = data[data['Tumor_Sample_Barcode'] == sample]
  170. message = []
  171. for gene in gene_list:
  172. num = len(tempdata[tempdata['Variant_Classification'] == gene])
  173. message.append(num)
  174. message = message + [sample, tempdata.iloc[0][-1]]
  175. data_final.loc[ind] = message
  176. ind = ind + 1
  177. # data_final
  178. logger.info('113113113113113')
  179. # -------------------------------------SC_GRS方法-----------------------------------------
  180. data_final['Gi_sum'] = data_final[gene_list].sum(axis=1)
  181. # 模型训练
  182. X_train, X_test, y_train, y_test = train_test_split(data_final['Gi_sum'].values.reshape(-1, 1),
  183. data_final['D'].values.astype('int'), test_size=0.2)
  184. clf = LogisticRegression(penalty='l2')
  185. clf.fit(X_train, y_train)
  186. # 模型系数
  187. coef = clf.coef_
  188. intercept = clf.intercept_
  189. # 储存模型
  190. joblib.dump(clf, base_path + 'clf_scgrs.model')
  191. logger.info('3333')
  192. # ROC绘图 SC_GRS方法
  193. ROC_curve(clf, X_test, y_test, name='SC_GRS')
  194. logger.info('130130130130 ' + base_path + 'gene_list.npy')
  195. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  196. # 模型加载
  197. gene_list = np.load(base_path + 'gene_list.npy')
  198. logger.info('134134134134134134134')
  199. gene_list = gene_list.tolist()
  200. clf_scgrs = joblib.load(base_path + 'clf_scgrs.model')
  201. logger.info('137137137137137')
  202. # 数据处理
  203. X_test = data_final[gene_list].sum(axis=1).values.reshape(-1, 1)
  204. y_pre = clf.predict(X_test) # 分类结果预测
  205. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  206. logger.info('4444')
  207. # 绘制图像
  208. plt.imshow(prob_pre)
  209. # 保存为图片文件
  210. plt.savefig(base_path + "SC_GRS_prob_pre.png", format='png')
  211. # second_column = prob_pre[1, :]
  212. # # 找到大于0.5的数据的下标
  213. # indices = np.where(second_column > 0.5)
  214. # # 绘制折线图
  215. # plt.plot(second_column)
  216. # plt.plot(indices, second_column[indices], 'ro')
  217. # # 设置 x 轴和 y 轴标签
  218. # plt.title('SC GRS Probability prediction greater than 0.5')
  219. # plt.xlabel('Index')
  220. # plt.ylabel('Value')
  221. # # 显示图形
  222. # plt.savefig(base_path + "GRS概率预测.png", format='png')
  223. save_end_data(prob_pre, 'SC_GRS', 'SC GRS')
  224. # -------------------------------------DL_GRS方法-----------------------------------------
  225. # 模型训练
  226. X_train, X_test, y_train, y_test = train_test_split(data_final[gene_list], data_final['D'].values.astype('int'),
  227. test_size=0.2)
  228. clf = LogisticRegression(penalty='l2', max_iter=1000)
  229. clf.fit(X_train, y_train)
  230. # 模型系数
  231. coef = clf.coef_
  232. intercept = clf.intercept_
  233. # 储存模型
  234. joblib.dump(clf, base_path + 'clf_dlgrs.model')
  235. # ROC曲线绘制,AUC值计算
  236. ROC_curve(clf, X_test, y_test, name='DL_GRS')
  237. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  238. # 模型加载
  239. gene_list = np.load(base_path + 'gene_list.npy')
  240. gene_list = gene_list.tolist()
  241. clf_scgrs = joblib.load(base_path + 'clf_dlgrs.model')
  242. # 数据处理
  243. X_test = data_final[gene_list]
  244. y_pre = clf.predict(X_test) # 分类结果预测
  245. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  246. # 绘制图像
  247. plt.imshow(prob_pre)
  248. # 保存为图片文件
  249. plt.savefig(base_path + "DL_GRS_prob_pre.png", format='png')
  250. save_end_data(prob_pre, 'DL_GRS', 'DL GRS')
  251. # -------------------------------------OR_GRS方法-----------------------------------------
  252. # or值计算
  253. tempdata = data_final[data_final['D'] == 0]
  254. gene_mean0 = tempdata[gene_list].mean()
  255. tempdata = data_final[data_final['D'] == 1]
  256. gene_mean1 = tempdata[gene_list].mean()
  257. omega_or = np.log(gene_mean1 / gene_mean0)
  258. # 模型训练
  259. data_final['GRS'] = (data_final[gene_list] * omega_or).sum(axis=1)
  260. X_train, X_test, y_train, y_test = train_test_split(data_final['GRS'].values.reshape(-1, 1),
  261. data_final['D'].values.astype('int'), test_size=0.2)
  262. clf = LogisticRegression(penalty='l2')
  263. clf.fit(X_train, y_train)
  264. # 模型系数
  265. coef = clf.coef_
  266. intercept = clf.intercept_
  267. # 储存模型
  268. joblib.dump(clf, base_path + 'clf_orgrs.model')
  269. np.save(base_path + 'omega_or.npy', np.array(omega_or)) # 保存为.npy格式
  270. # ROC曲线绘制,AUC值计算 OR_GRS方法
  271. ROC_curve(clf, X_test, y_test, name='OR_GRS')
  272. # 模型预测,假设data_final是输入数据,注意,data_final已经是处理好对应的基因数目的数据
  273. # 模型加载
  274. gene_list = np.load(base_path + 'gene_list.npy')
  275. gene_list = gene_list.tolist()
  276. clf_scgrs = joblib.load(base_path + 'clf_dlgrs.model')
  277. omega_or = np.load(base_path + 'omega_or.npy')
  278. # 数据处理
  279. X_test = (data_final[gene_list] * omega_or).sum(axis=1).values.reshape(-1, 1)
  280. y_pre = clf.predict(X_test) # 分类结果预测
  281. prob_pre = clf.predict_proba(X_test) # 结果概率预测
  282. # 绘制图像
  283. plt.imshow(prob_pre)
  284. # 保存为图片文件
  285. plt.savefig(base_path + "OR_GRS_prob_pre.png", format='png')
  286. save_end_data(prob_pre, 'OR_GRS', 'OR GRS')
  287. # logger.info(154, prob_pre)
  288. # 更新order状态
  289. try:
  290. order = Order_models.objects.get(id=order_id)
  291. order.status = 3
  292. order.save()
  293. # return JsonResponse(setSuccess(msg='更新成功'))
  294. except Order_models.DoesNotExist:
  295. logging.warning(json.dumps({
  296. 'msg': '订单不存在',
  297. 'order_id': order_id
  298. }))