# nri.py
  1. import pandas as pd
  2. from sklearn.linear_model import LogisticRegression
  3. from sklearn.metrics import classification_report
  4. from sklearn.metrics.regression import net_reclassification_index
  5. import matplotlib.pyplot as plt
  6. # 加载数据
  7. snp_data = pd.read_csv('snp_data.csv') # 假设SNP数据已保存在CSV文件中
  8. # other_data = pd.read_csv('other_data.csv') # 假设其他特征数据已保存在CSV文件中
  9. target = pd.read_csv('target.csv') # 假设目标数据已保存在CSV文件中
  10. # 合并数据
  11. data = pd.concat([snp_data, target], axis=1)
  12. # 划分训练集和测试集
  13. train_data = data.sample(frac=0.8, random_state=1)
  14. test_data = data.drop(train_data.index)
  15. # 训练基线模型和改进模型
  16. X_train = train_data.drop('target', axis=1)
  17. y_train = train_data['target']
  18. X_test = test_data.drop('target', axis=1)
  19. y_test = test_data['target']
  20. base_model = LogisticRegression(random_state=1)
  21. base_model.fit(X_train, y_train)
  22. improved_model = LogisticRegression(random_state=1, solver='liblinear', penalty='l1')
  23. improved_model.fit(X_train, y_train)
  24. # 使用模型进行分类并计算NRI指标
  25. base_proba = base_model.predict_proba(X_test)
  26. improved_proba = improved_model.predict_proba(X_test)
  27. nri = net_reclassification_index(y_test, base_proba[:, 1], improved_proba[:, 1])
  28. # 输出NRI指标
  29. print('NRI:', nri)
  30. # 可视化NRI指标
  31. nri_df = pd.DataFrame({'Model': ['Baseline', 'Improved'], 'NRI': [0, nri]})
  32. plt.bar(nri_df['Model'], nri_df['NRI'], color=['#1f77b4', '#ff7f0e'])
  33. plt.ylim([0, 1])
  34. plt.xlabel('Model')
  35. plt.ylabel('NRI')
  36. plt.title('Net Reclassification Improvement')
  37. plt.show()