Classification with XGBoost: predicting whether a patient has diabetes
2022-08-08 06:20:00 【波尔德】
The full code is below:
# First XGBoost model for Pima Indians dataset
# Binary classification with XGBoost using its default parameters: does the patient have diabetes or not?
from numpy import loadtxt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
# load the CSV file
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
# split data into X and y
X = dataset[:, 0:8]  # the first 8 columns are features
Y = dataset[:, 8]  # the last column is the label
# split data into train and test sets
seed = 7  # random seed
test_size = 0.33  # 67% training data, 33% test data
# split the dataset
X_train, X_test, y_train, y_test = \
    train_test_split(X, Y, test_size=test_size, random_state=seed)
model = XGBClassifier()  # create the XGBoost classifier
# fit model
model.fit(X_train, y_train)
# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
# print the accuracy with 2 decimal places
print("Accuracy:%.2f%%" % (accuracy * 100.0))
Output: [result image not preserved]
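The accuracy printed above is simply the fraction of test samples whose predicted label matches the true label. A minimal pure-Python sketch of that calculation follows; the `accuracy` helper and the label lists are hypothetical and only illustrate the idea, they are not the Pima data or the library's implementation.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / len(y_true)


# Hypothetical labels, for illustration only (not taken from the Pima dataset).
y_true = [1, 0, 0, 1, 1, 0, 1, 0]
y_pred = [1, 0, 1, 1, 0, 0, 1, 0]

# 6 of the 8 labels match, so this prints "Accuracy:75.00%".
print("Accuracy:%.2f%%" % (accuracy(y_true, y_pred) * 100.0))
```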
With eval_set, we can monitor the classification performance each time a new tree is added to the model.
# XGBoost model for Pima Indians dataset, with evaluation monitoring
# Each time a tree is added, its effect on the evaluation set is reported.
from numpy import loadtxt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
# split data into X and y
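X = dataset[:, 0:8]  # the first 8 columns are features
Y = dataset[:, 8]  # the last column is the label
# split data into train and test sets
seed = 7  # random seed
test_size = 0.33  # 67% training data, 33% test data
# split the dataset
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=test_size, random_state=seed)
# evaluation set that is scored each time a tree is added
eval_set = [(X_test, y_test)]
# Stop training if logloss has not improved for 10 consecutive rounds.
# Note: XGBoost 1.6+ takes early_stopping_rounds and eval_metric in the
# constructor; older versions passed them to fit() instead.
model = XGBClassifier(early_stopping_rounds=10, eval_metric="logloss")
# fit model on training data; verbose=True prints the metric after each added tree
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)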
# make predictions for test data
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
# evaluate predictions
accuracy = accuracy_score(y_test, predictions)
# print the accuracy with 2 decimal places
print("Accuracy:%.2f%%" % (accuracy * 100.0))
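The early-stopping rule used above (stop once the loss has failed to improve for a set number of consecutive rounds) can be sketched in plain Python. The function name, the `patience` parameter, and the loss values below are hypothetical and only illustrate the rule; this is not XGBoost's internal code.

```python
def best_round_with_early_stopping(losses, patience=10):
    """Return (best_index, stop_index) for a list of per-round validation losses.

    Training stops at the first round where `patience` consecutive rounds
    have passed without improving on the best loss seen so far. If that
    never happens, stop_index is the last round.
    """
    best_loss = float("inf")
    best_index = -1
    for i, loss in enumerate(losses):
        if loss < best_loss:
            best_loss = loss
            best_index = i
        elif i - best_index >= patience:
            return best_index, i
    return best_index, len(losses) - 1


# Hypothetical logloss values, one per boosting round (illustration only).
losses = [0.60, 0.55, 0.52, 0.53, 0.54, 0.55, 0.56]
best, stopped = best_round_with_early_stopping(losses, patience=3)
print("best round:", best, "stopped at round:", stopped)
```

With a patience of 3, the best loss is at round 2 and training stops at round 5, after three rounds without improvement.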
Use plot_importance to plot the importance of each feature
from numpy import loadtxt
from xgboost import XGBClassifier
from xgboost import plot_importance  # plot_importance plots the importance of each feature
from matplotlib import pyplot
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
# split data into X and y
X = dataset[:, 0:8]  # the first 8 columns are features
y = dataset[:, 8]  # the last column is the label
# fit model on all the data
model = XGBClassifier()
model.fit(X, y)
# plot feature importance
plot_importance(model)
pyplot.show()
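By default, plot_importance ranks features by "weight", that is, how many times each feature is used to split across all trees. A toy sketch of that counting is below; the `trees` structure and the `weight_importance` helper are hypothetical stand-ins for a fitted model, with feature names in the f0, f1, ... style XGBoost assigns to unnamed columns.

```python
from collections import Counter

# Hypothetical ensemble: each inner list holds the feature used at
# each split of one tree (illustration only, not a real fitted model).
trees = [
    ["f1", "f5", "f1"],
    ["f1", "f7"],
    ["f5", "f1", "f6"],
]


def weight_importance(trees):
    """Count how many splits use each feature across all trees."""
    counts = Counter()
    for splits in trees:
        counts.update(splits)
    return dict(counts)


importance = weight_importance(trees)
# Print features from most to least frequently used.
for feature, count in sorted(importance.items(), key=lambda kv: kv[1], reverse=True):
    print(feature, count)
```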

Use GridSearchCV to select the best learning rate
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV  # searches the grid and selects the best learning rate
from sklearn.model_selection import StratifiedKFold
# load data
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
# split data into X and y
X = dataset[:, 0:8]  # the first 8 columns are features
y = dataset[:, 8]  # the last column is the label
# grid search over the learning rate
model = XGBClassifier()
learning_rate = [0.0001, 0.001, 0.01, 0.1, 0.2, 0.3]
# wrap the candidate learning rates in a dict
param_grid = dict(learning_rate=learning_rate)
# stratified 10-fold cross-validation
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
# n_jobs=-1 uses all available CPU cores
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', n_jobs=-1, cv=kfold)
grid_result = grid_search.fit(X, y)
# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
params = grid_result.cv_results_['params']
for mean, param in zip(means, params):
    print("%f with %r" % (mean, param))
The results are as follows: [result image not preserved]
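The scores printed by the grid search are negative because scikit-learn always maximizes the scoring function, so log loss is negated: the candidate whose score is closest to zero wins. A small pure-Python sketch of that selection is below; the `log_loss` helper, the setting names, and the predicted probabilities are all hypothetical and are not results from the Pima dataset.

```python
import math


def log_loss(y_true, p_pred, eps=1e-15):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


# Hypothetical true labels and predicted probabilities of class 1
# for three made-up parameter settings (illustration only).
y_true = [1, 0, 1, 0]
candidates = {
    "setting_a": [0.5, 0.5, 0.5, 0.5],
    "setting_b": [0.9, 0.2, 0.8, 0.1],
    "setting_c": [0.99, 0.6, 0.7, 0.05],
}

# Negate the loss so that a higher score means a better model.
scores = {name: -log_loss(y_true, probs) for name, probs in candidates.items()}
best = max(scores, key=scores.get)
for name, score in scores.items():
    print("%f with %r" % (score, name))
print("Best: %f using %s" % (scores[best], best))
```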