当前位置:网站首页>lgb,xgb,cat k折交叉验证
lgb,xgb,cat k折交叉验证
2022-04-23 07:58:00 【鈴音.】
import lightgbm as lgb
import xgboost as xgb
import catboost as cat
def _take_rows(data, index):
    """Select rows of ``data`` by integer position.

    pandas objects are indexed through ``.iloc`` so that a non-default index
    cannot silently turn the lookup into label-based selection; anything
    else (e.g. a NumPy array) is indexed directly.
    """
    if hasattr(data, "iloc"):
        return data.iloc[index]
    return data[index]


def cv_model(clf, train_x, train_y, test_x, clf_name, folds=5, seed=2020):
    """Run K-fold cross-validation for a LightGBM, XGBoost or CatBoost model.

    Each fold trains on (folds - 1)/folds of the training rows, predicts the
    held-out rows (out-of-fold predictions) and the test set, and reports the
    fold's ROC AUC.

    Args:
        clf: For ``"lgb"`` / ``"xgb"``, the ``lightgbm`` / ``xgboost`` module
            itself (its ``Dataset``/``DMatrix`` and ``train`` are used).
            For ``"cat"``, a CatBoost estimator class, instantiated as
            ``clf(iterations=20000, **params)``.
        train_x: Training features; rows are selected by position.
        train_y: Training labels; binary, since the lgb/xgb objectives are
            binary and folds are scored with ROC AUC.
        test_x: Features to predict on with every fold model.
        clf_name: One of ``"lgb"``, ``"xgb"``, ``"cat"``.
        folds: Number of CV folds (default 5).
        seed: Random seed for the shuffled fold split (default 2020).

    Returns:
        ``(train, test)``: out-of-fold predictions for every training row,
        and test-set predictions averaged over the ``folds`` fold models.

    Raises:
        ValueError: If ``clf_name`` is not a supported model name.

    NOTE(review): relies on ``np``, ``KFold`` (sklearn.model_selection) and
    ``roc_auc_score`` (sklearn.metrics) being imported at module level; they
    are not imported in the visible part of this file.
    """
    if clf_name not in ("lgb", "xgb", "cat"):
        raise ValueError(
            "clf_name must be 'lgb', 'xgb' or 'cat', got %r" % (clf_name,))

    kf = KFold(n_splits=folds, shuffle=True, random_state=seed)
    # Out-of-fold predictions and the running average of test predictions.
    train = np.zeros(train_x.shape[0])
    test = np.zeros(test_x.shape[0])
    cv_scores = []

    # One (train_index, valid_index) pair per fold.
    for i, (train_index, valid_index) in enumerate(kf.split(train_x, train_y)):
        print('************************************ {} ************************************'.format(str(i+1)))
        trn_x = _take_rows(train_x, train_index)
        trn_y = _take_rows(train_y, train_index)
        val_x = _take_rows(train_x, valid_index)
        val_y = _take_rows(train_y, valid_index)

        if clf_name == "lgb":
            # LightGBM trains from its own Dataset containers.
            train_matrix = clf.Dataset(trn_x, label=trn_y)
            valid_matrix = clf.Dataset(val_x, label=val_y)
            params = {
                'boosting_type': 'gbdt',
                'objective': 'binary',
                'metric': 'auc',
                'min_child_weight': 5,
                'num_leaves': 2 ** 5,
                'lambda_l2': 10,
                'feature_fraction': 0.8,
                'bagging_fraction': 0.8,
                'bagging_freq': 4,
                'learning_rate': 0.1,
                'seed': 2020,
                # NOTE(review): nthread and n_jobs are both aliases of
                # num_threads; LightGBM warns and honours only one of them.
                'nthread': 28,
                'n_jobs': 24,
                # 'silent' is not a LightGBM training parameter; 'verbose'
                # below is what silences the library.
                'verbose': -1,
            }
            # verbose_eval/early_stopping_rounds were removed from lgb.train
            # in LightGBM 4.0; the equivalent callbacks work on 3.3+ as well.
            model = clf.train(
                params, train_matrix, 50000,
                valid_sets=[train_matrix, valid_matrix],
                callbacks=[clf.log_evaluation(200), clf.early_stopping(200)],
            )
            val_pred = model.predict(val_x, num_iteration=model.best_iteration)
            test_pred = model.predict(test_x, num_iteration=model.best_iteration)
        elif clf_name == "xgb":
            train_matrix = clf.DMatrix(trn_x, label=trn_y)
            valid_matrix = clf.DMatrix(val_x, label=val_y)
            test_matrix = clf.DMatrix(test_x)
            params = {
                'booster': 'gbtree',
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'gamma': 1,
                'min_child_weight': 1.5,
                'max_depth': 5,
                'lambda': 10,
                'subsample': 0.7,
                'colsample_bytree': 0.7,
                'colsample_bylevel': 0.7,
                'eta': 0.04,
                'tree_method': 'exact',
                'seed': 2020,
                'nthread': 36,
                # Replaces the deprecated 'silent': True.
                'verbosity': 0,
            }
            watchlist = [(train_matrix, 'train'), (valid_matrix, 'eval')]
            model = clf.train(params, train_matrix, num_boost_round=50000,
                              evals=watchlist, verbose_eval=200,
                              early_stopping_rounds=200)
            # ntree_limit/best_ntree_limit were removed in XGBoost 2.0; for a
            # single-tree-per-round gbtree this range selects the same trees.
            best_range = (0, model.best_iteration + 1)
            val_pred = model.predict(valid_matrix, iteration_range=best_range)
            test_pred = model.predict(test_matrix, iteration_range=best_range)
        else:  # clf_name == "cat"
            params = {
                'learning_rate': 0.05, 'depth': 5, 'l2_leaf_reg': 10,
                'bootstrap_type': 'Bernoulli', 'od_type': 'Iter',
                'od_wait': 50, 'random_seed': 11,
                'allow_writing_files': False,
            }
            model = clf(iterations=20000, **params)
            model.fit(trn_x, trn_y, eval_set=(val_x, val_y),
                      cat_features=[], use_best_model=True, verbose=500)
            # A classifier's predict() returns hard 0/1 labels, which collapses
            # the AUC to a single threshold and makes fold averaging crude;
            # score with P(class 1) instead. Regressors keep predict().
            if hasattr(model, "predict_proba"):
                val_pred = model.predict_proba(val_x)[:, 1]
                test_pred = model.predict_proba(test_x)[:, 1]
            else:
                val_pred = model.predict(val_x)
                test_pred = model.predict(test_x)

        train[valid_index] = val_pred
        # Accumulate: each fold model contributes 1/folds of the final test
        # prediction (the original '=' kept only the last fold).
        test += test_pred / kf.n_splits
        cv_scores.append(roc_auc_score(val_y, val_pred))
        print(cv_scores)

    print("%s_score_list:" % clf_name, cv_scores)
    print("%s_score_mean:" % clf_name, np.mean(cv_scores))
    print("%s_score_std:" % clf_name, np.std(cv_scores))
    return train, test
版权声明
本文为[鈴音.]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_52691614/article/details/124334578
边栏推荐
猜你喜欢
stm32以及freertos 堆栈解析
跨域配置报错: When allowCredentials is true, allowedOrigins cannot contain the special value “*”
Community group purchase applet source code + interface DIY + nearby leader + supplier + group collage + recipe + second kill + pre-sale + distribution + live broadcast
How to generate assembly file
The simple problem of leetcode is to calculate the numerical sum of strings
【深度好文】Flink SQL流批⼀体化技术详解(一)
项目上传部分
396. Rotate Function
Failed to convert a NumPy array to a Tensor(Unsupported Object type int)
Campus transfer second-hand market source code download
随机推荐
WordPress love navigation theme 1.1.3 simple atmosphere website navigation source code website navigation source code
对li类数组对象随机添加特性,并进行排序
Comparison of indoor positioning technology
Record: JS several methods to delete one or more items in the array
Navicat远程连接mysql
The third divisor of leetcode simple question
【学习】从零开始的音视频开发(9)——NuPlayer
DOM learning - add + - button
2022.4.11-4.17 AI行业周刊(第93期):AI行业的困局
室内定位技术对比
pgsql想实现mysql一样的列子查询操作
信息收集相关知识点及题解
Let the earth have less "carbon" and rest on the road
Ansible Automation Operation and Maintenance details (ⅰ) Installation and Deployment, Parameter use, list Management, Profile Parameters and user level ansible operating environment Construction
DOM 学习之—添加+-按钮
JSP page coding
Ear acupoint diagnosis and treatment essay 0421
Qt读写XML文件
npm安装yarn
5.6 comprehensive case - RTU-