当前位置:网站首页>lgb,xgb,cat k折交叉验证
lgb,xgb,cat k折交叉验证
2022-04-23 07:58:00 【鈴音.】
import catboost as cat
import lightgbm as lgb
import numpy as np
import xgboost as xgb
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
#传入:模型,训练集x,训练集y,测试集x,模型的名字
def cv_model(clf, train_x, train_y, test_x, clf_name):
    """Run 5-fold cross-validation and return out-of-fold and test predictions.

    For every fold a fresh model is trained on four fifths of ``train_x`` and
    evaluated (ROC AUC) on the remaining fifth. The held-out predictions are
    written into an out-of-fold array, and the predictions on ``test_x`` are
    averaged over all folds.

    Args:
        clf: The model backend, interpreted according to ``clf_name``:
            for ``"lgb"`` an object exposing ``Dataset`` and ``train``;
            for ``"xgb"`` an object exposing ``DMatrix`` and ``train``;
            for ``"cat"`` a callable taking ``iterations`` plus keyword
            parameters and returning an object with ``fit`` and ``predict``.
        train_x: Training features; must support ``.iloc`` and ``.shape``.
        train_y: Training labels, aligned row-by-row with ``train_x``.
        test_x: Test features; must expose ``.shape``.
        clf_name: One of ``"lgb"``, ``"xgb"`` or ``"cat"``.

    Returns:
        A ``(train, test)`` tuple of 1-D arrays: ``train`` holds, for every
        training row, the prediction made by the fold in which that row was
        held out; ``test`` holds the mean of the per-fold test predictions.

    Raises:
        ValueError: If ``clf_name`` is not a supported backend name.
    """
    supported = ("lgb", "xgb", "cat")
    if clf_name not in supported:
        # Fail early with a clear message instead of a NameError on
        # ``val_pred`` after the first fold has been set up.
        raise ValueError(
            "unsupported clf_name {!r}; expected one of {}".format(clf_name, supported)
        )

    # 5 folds, fixed seed so that the split is reproducible.
    folds = 5
    seed = 2020
    kf = KFold(n_splits=folds, shuffle=True, random_state=seed)

    # Out-of-fold predictions for the training rows and the averaged
    # predictions for the test rows.
    train = np.zeros(train_x.shape[0])
    test = np.zeros(test_x.shape[0])
    cv_scores = []

    # Each iteration yields positional indices of the training part and of
    # the held-out validation part of this fold.
    for i, (train_index, valid_index) in enumerate(kf.split(train_x, train_y)):
        print('************************************ {} ************************************'.format(str(i+1)))

        trn_x, val_x = train_x.iloc[train_index], train_x.iloc[valid_index]
        # The fold indices are positional, so index the labels positionally
        # as well; plain ``train_y[idx]`` on a pandas Series is label-based
        # and misaligns rows when the index is not 0..n-1.
        if hasattr(train_y, "iloc"):
            trn_y, val_y = train_y.iloc[train_index], train_y.iloc[valid_index]
        else:
            trn_y, val_y = train_y[train_index], train_y[valid_index]

        if clf_name == "lgb":
            # The data has to be wrapped in Dataset objects for training.
            train_matrix = clf.Dataset(trn_x, label=trn_y)
            valid_matrix = clf.Dataset(val_x, label=val_y)

            # NOTE(review): 'nthread' and 'n_jobs' both look like
            # thread-count settings with different values - confirm which
            # one is intended.
            params = {
                'boosting_type': 'gbdt',
                'objective': 'binary',
                'metric': 'auc',
                'min_child_weight': 5,
                'num_leaves': 2 ** 5,
                'lambda_l2': 10,
                'feature_fraction': 0.8,
                'bagging_fraction': 0.8,
                'bagging_freq': 4,
                'learning_rate': 0.1,
                'seed': 2020,
                'nthread': 28,
                'n_jobs': 24,
                'silent': True,
                'verbose': -1,
            }

            # NOTE(review): whether ``verbose_eval`` / ``early_stopping_rounds``
            # are accepted depends on the installed library version - confirm.
            model = clf.train(params, train_matrix, 50000,
                              valid_sets=[train_matrix, valid_matrix],
                              verbose_eval=200, early_stopping_rounds=200)
            val_pred = model.predict(val_x, num_iteration=model.best_iteration)
            test_pred = model.predict(test_x, num_iteration=model.best_iteration)

        elif clf_name == "xgb":
            train_matrix = clf.DMatrix(trn_x, label=trn_y)
            valid_matrix = clf.DMatrix(val_x, label=val_y)
            test_matrix = clf.DMatrix(test_x)

            params = {
                'booster': 'gbtree',
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'gamma': 1,
                'min_child_weight': 1.5,
                'max_depth': 5,
                'lambda': 10,
                'subsample': 0.7,
                'colsample_bytree': 0.7,
                'colsample_bylevel': 0.7,
                'eta': 0.04,
                'tree_method': 'exact',
                'seed': 2020,
                'nthread': 36,
                "silent": True,
            }

            watchlist = [(train_matrix, 'train'), (valid_matrix, 'eval')]

            # NOTE(review): ``ntree_limit`` / ``best_ntree_limit`` depend on
            # the installed library version - confirm.
            model = clf.train(params, train_matrix, num_boost_round=50000,
                              evals=watchlist, verbose_eval=200,
                              early_stopping_rounds=200)
            val_pred = model.predict(valid_matrix, ntree_limit=model.best_ntree_limit)
            test_pred = model.predict(test_matrix, ntree_limit=model.best_ntree_limit)

        else:  # clf_name == "cat"
            params = {
                'learning_rate': 0.05, 'depth': 5, 'l2_leaf_reg': 10,
                'bootstrap_type': 'Bernoulli', 'od_type': 'Iter',
                'od_wait': 50, 'random_seed': 11, 'allow_writing_files': False,
            }

            # Here ``clf`` is called to construct the estimator itself.
            model = clf(iterations=20000, **params)
            model.fit(trn_x, trn_y, eval_set=(val_x, val_y),
                      cat_features=[], use_best_model=True, verbose=500)

            val_pred = model.predict(val_x)
            test_pred = model.predict(test_x)

        train[valid_index] = val_pred
        # Accumulate (not overwrite) so that ``test`` ends up as the mean of
        # the test predictions of all folds.
        test += test_pred / kf.n_splits
        cv_scores.append(roc_auc_score(val_y, val_pred))

        print(cv_scores)

    print("%s_score_list:" % clf_name, cv_scores)
    print("%s_score_mean:" % clf_name, np.mean(cv_scores))
    print("%s_score_std:" % clf_name, np.std(cv_scores))
    return train, test
版权声明
本文为[鈴音.]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_52691614/article/details/124334578
边栏推荐
- ELK生产实践
- form中enctype属性
- 关于ORB——SLAM运行中关键帧位置越来越近的异常说明
- vmware 搭建ES8的常见错误
- dmp引擎工作总结(2021,光剑)
- redis主从服务器问题
- npm安装yarn
- [effective go Chinese translation] part I
- Qt读取路径下所有文件或指定类型文件(含递归、判断是否为空、创建路径)
- Listed on the Shenzhen Stock Exchange: the market value is 5.2 billion yuan. Lu is the East and his daughter is American
猜你喜欢

SYS_CONNECT_BY_PATH(column, 'char') combined with START WITH ... CONNECT BY PRIOR

Shell script advanced

LeetCode簡單題之計算字符串的數字和

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (paper summary)

Online yaml to XML tool

作文以记之 ~ 二叉树的后序遍历

ASAN 极简原理

Asan minimalism

396. Rotate Function

2022.4.11-4.17 AI行业周刊(第93期):AI行业的困局
随机推荐
Failed to convert a NumPy array to a Tensor(Unsupported Object type int)
JSP page coding
一键清理项目下pycharm和Jupyter缓存文件
分布式消息中间件框架选型-数字化架构设计(7)
Goland 调试go使用-大白记录
Description of the abnormity that the key frame is getting closer and closer in the operation of orb slam
RPC过程
pdf加水印
Record: JS several methods to delete one or more items in the array
什么是RPC
Qtablewidget header customization and beautification developed by pyqt5 (with source code download)
作文以记之 ~ 二叉树的后序遍历
396. Rotate Function
QT reads all files under the path or files of the specified type (including recursion, judging whether it is empty and creating the path)
Asan minimalism
ansible自动化运维详解(一)ansible的安装部署、参数使用、清单管理、配置文件参数及用户级ansible操作环境构建
Use of applicationreadyevent
Search the complete navigation program source code
excel加水印
Rust: using Tokio's Notify and timeout to achieve an effect similar to a condition variable with a timeout