当前位置:网站首页>LGB, XGB, cat, k-fold cross validation
LightGBM, XGBoost and CatBoost with k-fold cross-validation
2022-04-23 08:56:00 【Ring tone】
import lightgbm as lgb
import xgboost as xgb
import catboost as cat
# Pass in : Model , Training set x, Training set y, Test set x, The name of the model
def cv_model(clf, train_x, train_y, test_x, clf_name):
    """Train one model per fold with 5-fold CV and average the test predictions.

    The names ``np``, ``KFold`` and ``roc_auc_score`` must be available at
    module scope (NumPy, ``sklearn.model_selection.KFold`` and
    ``sklearn.metrics.roc_auc_score``).

    Args:
        clf: What to train with. For ``"lgb"`` and ``"xgb"`` this is the
            library module itself (``clf.Dataset`` / ``clf.DMatrix`` and
            ``clf.train`` are called on it). For ``"cat"`` it is a callable
            that builds an estimator, invoked as ``clf(iterations=..., **params)``.
        train_x: Training features; must support ``.iloc`` and ``.shape``.
        train_y: Training labels, a pandas Series or a NumPy array.
        test_x: Test features; must support ``.shape``.
        clf_name: One of ``"lgb"``, ``"xgb"`` or ``"cat"``.

    Returns:
        A ``(train, test)`` tuple: ``train`` holds the out-of-fold prediction
        for every training row, ``test`` holds the test-set prediction
        averaged over all folds.

    Raises:
        ValueError: If ``clf_name`` is not one of the supported names.
    """
    if clf_name not in ("lgb", "xgb", "cat"):
        raise ValueError(
            "clf_name must be 'lgb', 'xgb' or 'cat', got %r" % (clf_name,)
        )

    # Number of folds and the random seed used for shuffling.
    folds = 5
    seed = 2020
    # Each round uses 4 parts for training and the remaining part for validation.
    kf = KFold(n_splits=folds, shuffle=True, random_state=seed)

    # Zero-filled accumulators: out-of-fold predictions and averaged test predictions.
    train = np.zeros(train_x.shape[0])
    test = np.zeros(test_x.shape[0])
    cv_scores = []

    # kf.split yields positional indices, so labels must be selected by
    # position as well; plain [] on a pandas Series would be label-based.
    labels_by_position = train_y.iloc if hasattr(train_y, "iloc") else train_y

    # One (train_index, valid_index) pair is produced per fold.
    for i, (train_index, valid_index) in enumerate(kf.split(train_x, train_y)):
        print('************************************ {} ************************************'.format(str(i+1)))
        # Training features/labels and validation features/labels for this fold.
        trn_x, val_x = train_x.iloc[train_index], train_x.iloc[valid_index]
        trn_y, val_y = labels_by_position[train_index], labels_by_position[valid_index]

        if clf_name == "lgb":
            # LightGBM consumes its own Dataset wrapper.
            train_matrix = clf.Dataset(trn_x, label=trn_y)
            valid_matrix = clf.Dataset(val_x, label=val_y)

            params = {
                'boosting_type': 'gbdt',
                'objective': 'binary',
                'metric': 'auc',
                'min_child_weight': 5,
                'num_leaves': 2 ** 5,
                'lambda_l2': 10,
                'feature_fraction': 0.8,
                'bagging_fraction': 0.8,
                'bagging_freq': 4,
                'learning_rate': 0.1,
                'seed': 2020,
                'nthread': 28,
                'n_jobs': 24,
                'silent': True,
                'verbose': -1,
            }

            # NOTE(review): verbose_eval / early_stopping_rounds are accepted
            # only by older LightGBM releases; newer ones expect callbacks.
            # Confirm against the installed version.
            model = clf.train(params, train_matrix, 50000, valid_sets=[train_matrix, valid_matrix], verbose_eval=200, early_stopping_rounds=200)
            val_pred = model.predict(val_x, num_iteration=model.best_iteration)
            test_pred = model.predict(test_x, num_iteration=model.best_iteration)

        elif clf_name == "xgb":
            # XGBoost consumes its own DMatrix wrapper.
            train_matrix = clf.DMatrix(trn_x, label=trn_y)
            valid_matrix = clf.DMatrix(val_x, label=val_y)
            test_matrix = clf.DMatrix(test_x)

            params = {
                'booster': 'gbtree',
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
                'gamma': 1,
                'min_child_weight': 1.5,
                'max_depth': 5,
                'lambda': 10,
                'subsample': 0.7,
                'colsample_bytree': 0.7,
                'colsample_bylevel': 0.7,
                'eta': 0.04,
                'tree_method': 'exact',
                'seed': 2020,
                'nthread': 36,
                "silent": True,
            }

            watchlist = [(train_matrix, 'train'), (valid_matrix, 'eval')]

            # NOTE(review): ntree_limit / best_ntree_limit exist only in older
            # XGBoost releases; confirm against the installed version.
            model = clf.train(params, train_matrix, num_boost_round=50000, evals=watchlist, verbose_eval=200, early_stopping_rounds=200)
            val_pred = model.predict(valid_matrix, ntree_limit=model.best_ntree_limit)
            test_pred = model.predict(test_matrix, ntree_limit=model.best_ntree_limit)

        else:  # clf_name == "cat"
            params = {
                'learning_rate': 0.05, 'depth': 5, 'l2_leaf_reg': 10, 'bootstrap_type': 'Bernoulli',
                'od_type': 'Iter', 'od_wait': 50, 'random_seed': 11, 'allow_writing_files': False}

            model = clf(iterations=20000, **params)
            model.fit(trn_x, trn_y, eval_set=(val_x, val_y),
                      cat_features=[], use_best_model=True, verbose=500)

            # NOTE(review): predict() is scored with AUC below; if clf is a
            # classifier that returns hard labels here, the score is computed
            # on labels rather than probabilities. Verify against the caller.
            val_pred = model.predict(val_x)
            test_pred = model.predict(test_x)

        train[valid_index] = val_pred
        # Accumulate so that `test` ends up as the mean over all folds.
        test += test_pred / kf.n_splits
        cv_scores.append(roc_auc_score(val_y, val_pred))

        print(cv_scores)

    print("%s_score_list:" % clf_name, cv_scores)
    print("%s_score_mean:" % clf_name, np.mean(cv_scores))
    print("%s_score_std:" % clf_name, np.std(cv_scores))
    return train, test
版权声明
本文为[Ring tone]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230758387342.html
边栏推荐
猜你喜欢
Please arrange star trek in advance to break through the new playing method of chain tour, and the market heat continues to rise
Test your machine learning pipeline
Star Trek强势来袭 开启元宇宙虚拟与现实的梦幻联动
Use of Arthas in JVM tools
After a circle, I sorted out this set of interview questions..
DJ音乐管理软件Pioneer DJ rekordbox
Cadence process angle simulation, Monte Carlo simulation, PSRR
LeetCode_ DFS_ Medium_ 1254. Count the number of closed islands
MySQL小練習(僅適合初學者,非初學者勿進)
L2-024 tribe (25 points) (and check the collection)
随机推荐
Yangtao electronic STM32 Internet of things entry 30 step notes IV. engineering compilation and download
是否同一棵二叉搜索树 (25 分)
MySQL查询两张表属性值非重复的数据
Share the office and improve the settled experience
MySQL small exercise (only suitable for beginners, non beginners are not allowed to enter)
在sqli-liabs学习SQL注入之旅(第十一关~第二十关)
[indexof] [lastIndexOf] [split] [substring] usage details
ONEFLOW learning notes: from functor to opexprinter
Idea is configured to connect to the remote database mysql, or Navicat fails to connect to the remote database (solved)
Reference passing 1
PLC point table (register address and point table definition) cracking detection scheme -- convenient for industrial Internet data acquisition
求简单类型的矩阵和
Pctp test experience sharing
LaTeX论文排版操作
uni-app和微信小程序中的getCurrentPages()
应纳税所得额
Find the sum of simple types of matrices
Complete binary search tree (30 points)
Swagger document export custom V2 / API docs interception
L2-024 部落 (25 分)(并查集)