当前位置:网站首页>A case of violent parameter tuning in machine learning
A case of violent parameter tuning in machine learning
2022-08-10 11:33:00 【51CTO】
暴力调参案例
使用的数据集为
from sklearn.datasets import fetch_20newsgroups
因为在线下载慢,可以提前下载保存到本地目录(例如 ./datas/),再通过 data_home 参数加载。

首先引入所需库
# Numeric / dataframe libraries.
import numpy as np
import pandas as pd

defaultencoding = 'utf-8'

# Plotting.
import matplotlib as mpl
import matplotlib.pyplot as plt

# scikit-learn estimators and utilities.
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SelectKBest, chi2
import sklearn.metrics as metrics
from sklearn.datasets import fetch_20newsgroups
import sys
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
编码问题显示
如果报错的话可以改为
用来正常显示中文 mpl.rcParams['font.sans-serif']=[u'simHei'] 用来正常显示正负号 mpl.rcParams['axes.unicode_minus']=False
获取数据
# data_home="./datas/" is the local cache directory for the downloaded news;
# subset='train' selects the training split; categories limits which
# newsgroups are fetched.
datas = fetch_20newsgroups(data_home="./datas/",
                           subset='train',
                           categories=['alt.atheism',
                                       'comp.graphics',
                                       'comp.os.ms-windows.misc'])
datas_test = fetch_20newsgroups(data_home="./datas/",
                                subset='test',
                                categories=['alt.atheism',
                                            'comp.graphics',
                                            'comp.os.ms-windows.misc'])
train_x = datas.data          # training documents (news X)
train_y = datas.target        # training labels (news Y)
test_x = datas_test.data      # test documents
test_y = datas_test.target    # test labels
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
自动调参
import time


def setParam(algo, name):
    """Grid-search the main hyper-parameter of *algo* and report results.

    Fits on the module-level ``train_x``/``train_y`` and evaluates on
    ``test_x``/``test_y``.

    Parameters
    ----------
    algo : estimator
        An unfitted scikit-learn classifier.
    name : str
        Display name used in the printout and the returned tuple.

    Returns
    -------
    tuple
        (name, estimated tuning cost in seconds, training error rate,
        test error rate)
    """
    gridSearch = GridSearchCV(algo, param_grid={}, cv=5)
    m = 0
    # Pick the parameter grid that matches the estimator's tunable
    # attribute.  The checks run sequentially; if an estimator exposed
    # more than one of these attributes, the last matching grid wins.
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = 10
    if hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    if hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    if m == 0:
        # No known hyper-parameter: fit the estimator with its defaults
        # (a single empty candidate) instead of letting GridSearchCV
        # fail on an empty candidate list.
        gridSearch.set_params(param_grid={})
        m = 1

    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1

    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # t2 covers cv=5 folds per candidate (plus two predicts); divide by
    # the fold count and scale by candidate count to estimate cost.
    return name, t2 / 5 * m, train_error, test_error
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
选择算法调参
朴素贝叶斯,随机森林,KNN
可视化
# Split the (name, elapsed time, train error, test error) tuples into
# four parallel lists, one per column.
names, times, train_err, test_err = [[row[i] for row in results]
                                     for i in range(4)]

xs = np.arange(len(names))
axes = plt.axes()
# One thin bar group per metric, offset so they sit side by side.
for offset, values, colour, caption in [(0.0, times, "red", "耗费时间"),
                                        (0.1, train_err, "green", "训练集错误"),
                                        (0.2, test_err, "blue", "测试集错误")]:
    axes.bar(xs + offset, values, color=colour, label=caption, width=0.1)
plt.xticks(xs, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
代码整合:
#coding=UTF-8
"""Brute-force hyper-parameter tuning demo on the 20 newsgroups dataset."""
import numpy as np
import pandas as pd

defaultencoding = 'utf-8'

import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SelectKBest, chi2
import sklearn.metrics as metrics
from sklearn.datasets import fetch_20newsgroups
import sys
import importlib

# Python 2 legacy: force the interpreter's default string encoding to
# UTF-8.  On Python 3 getdefaultencoding() is always 'utf-8', so the
# branch is never taken; guard with hasattr because reload(sys) does NOT
# restore setdefaultencoding on Python 3 and calling it would raise
# AttributeError.
if sys.getdefaultencoding() != defaultencoding:
    # reload(sys)  # Python 2 spelling
    importlib.reload(sys)
    if hasattr(sys, "setdefaultencoding"):
        sys.setdefaultencoding(defaultencoding)

# Render Chinese labels and minus signs correctly in matplotlib.
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
# data_home="./datas/" is the local cache directory for the downloaded
# news; subset picks the train/test split; categories limits which
# newsgroups are fetched.
news_categories = ['alt.atheism',
                   'comp.graphics',
                   'comp.os.ms-windows.misc']
datas = fetch_20newsgroups(data_home="./datas/", subset='train',
                           categories=news_categories)
datas_test = fetch_20newsgroups(data_home="./datas/", subset='test',
                                categories=news_categories)

train_x = datas.data          # training documents (news X)
train_y = datas.target        # training labels (news Y)
test_x = datas_test.data      # test documents
test_y = datas_test.target    # test labels

# Vectorize raw documents into TF-IDF features (English stop words removed).
tfidf = TfidfVectorizer(stop_words="english")
train_x = tfidf.fit_transform(train_x, train_y)
test_x = tfidf.transform(test_x)
print(train_x.shape)

# Chi-squared feature selection: keep only the 1000 best columns.
best = SelectKBest(chi2, k=1000)
train_x = best.fit_transform(train_x, train_y)
test_x = best.transform(test_x)
import time


def setParam(algo, name):
    """Grid-search the main hyper-parameter of *algo* and report results.

    Fits on the module-level ``train_x``/``train_y`` and evaluates on
    ``test_x``/``test_y``.

    Parameters
    ----------
    algo : estimator
        An unfitted scikit-learn classifier.
    name : str
        Display name used in the printout and the returned tuple.

    Returns
    -------
    tuple
        (name, estimated tuning cost in seconds, training error rate,
        test error rate)
    """
    gridSearch = GridSearchCV(algo, param_grid={}, cv=5)
    m = 0
    # Pick the parameter grid that matches the estimator's tunable
    # attribute.  The checks run sequentially; if an estimator exposed
    # more than one of these attributes, the last matching grid wins.
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = 10
    if hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    if hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    if m == 0:
        # No known hyper-parameter: fit the estimator with its defaults
        # (a single empty candidate) instead of letting GridSearchCV
        # fail on an empty candidate list.
        gridSearch.set_params(param_grid={})
        m = 1

    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1

    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # t2 covers cv=5 folds per candidate (plus two predicts); divide by
    # the fold count and scale by candidate count to estimate cost.
    return name, t2 / 5 * m, train_error, test_error
results = []
plt.figure()

# Candidate classifiers: naive Bayes, random forest and KNN.
algorithm = [("mnb", MultinomialNB()),
             ("random", RandomForestClassifier()),
             ("knn", KNeighborsClassifier())]
for name, algo in algorithm:
    results.append(setParam(algo, name))

# Split the (name, elapsed time, train error, test error) tuples into
# four parallel lists, one per column.
names, times, train_err, test_err = [[row[i] for row in results]
                                     for i in range(4)]

xs = np.arange(len(names))
axes = plt.axes()
# One thin bar group per metric, offset so they sit side by side.
for offset, values, colour, caption in [(0.0, times, "red", "耗费时间"),
                                        (0.1, train_err, "green", "训练集错误"),
                                        (0.2, test_err, "blue", "测试集错误")]:
    axes.bar(xs + offset, values, color=colour, label=caption, width=0.1)
plt.xticks(xs, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
结果:


边栏推荐
猜你喜欢

2022年裁员潮,失业程序员何去何从?

GPU加速Pinterest推荐模型,参数量增加100倍,用户活跃度提高16%

Memory problems difficult to locate, it is because you do not use ASAN

3款不同类型的自媒体免费工具,有效提高创作、运营效率

【勇敢饭饭,不怕刷题之链表】链表倒数节点问题
今天面了个腾讯拿38K出来的大佬,让我见识到了基础的天花板

怎么加入自媒体,了解这5种变现模式,让账号快速变现

谷歌数据中心发生“电力事故”造成 3 人受伤

STM32封装ESP8266一键配置函数:实现实现AP模式和STA模式切换、服务器与客户端创建

即时零售业态下如何实现自动做账?
随机推荐
Research on motion capture system for indoor combined positioning technology
内存问题难定位,那是因为你没用ASAN
HDU 1520 Anniversary party (tree dp)
越折腾越好用的 3 款开源 APP
[Brave food, not afraid to write the linked list] The problem of the penultimate node of the linked list
runtime-core.esm-bundler.js?d2dd:218 Uncaught TypeError: formRef.value?.validate is not a function
【勇敢饭饭,不怕刷题之链表】链表反转的几种情况
Centos7环境使用Mysql离线安装包安装Mysql5.7
POJ 3101 Astronomy (Mathematics)
4 of huawei offer levels, incredibly side is easing the bit in the interview ali?
Hangdian Multi-School-Loop-(uncertainty greedy + line segment tree)
2022年裁员潮,失业程序员何去何从?
1-IMU参数解析以及选择
Nocalhost - 让云原生时代的开发更高效
flask-restplus接口地址404问题
Several small projects that I have open sourced over the years
阻塞 非阻塞 poll机制 异步
Codeforces 814 C. An impassioned circulation of affection (dp)
英特尔推送20220809 CPU微码更新 修补Intel-SA-00657安全漏洞
【勇敢饭饭,不怕刷题之链表】链表中有环的问题