当前位置:网站首页>A case of violent parameter tuning in machine learning
A case of brute-force parameter tuning (grid search) in machine learning
2022-08-10 11:33:00 【51CTO】
暴力调参案例
使用的数据集为
from sklearn.datasets import fetch_20newsgroups
因为在线下载慢,可以提前下载保存到 ./datas/ 目录(即下文代码中 data_home 指定的路径)

首先引入所需库
import
numpy
as
np
import
pandas
as
pd
defaultencoding
=
'utf-8'
import
matplotlib
as
mpl
import
matplotlib.
pyplot
as
plt
from
sklearn.
naive_bayes
import
MultinomialNB
from
sklearn.
neighbors
import
KNeighborsClassifier
from
sklearn.
linear_model
import
LogisticRegression
from
sklearn.
ensemble
import
RandomForestClassifier
from
sklearn.
feature_extraction.
text
import
TfidfVectorizer
from
sklearn.
model_selection
import
GridSearchCV
from
sklearn.
feature_selection
import
SelectKBest,
chi2
import
sklearn.
metrics
as
metrics
from
sklearn.
datasets
import
fetch_20newsgroups
import
sys
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
编码问题显示
如果报错的话可以改为
# 用来正常显示中文
mpl.rcParams['font.sans-serif'] = [u'simHei']
# 用来正常显示正负号
mpl.rcParams['axes.unicode_minus'] = False
获取数据
# Load the same three newsgroups twice: once from the training split and once
# from the test split. data_home is the directory the dataset is stored in,
# subset picks the split and categories restricts which newsgroups are read.
datas, datas_test = (
    fetch_20newsgroups(
        data_home="./datas/",
        subset=split,
        categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'],
    )
    for split in ('train', 'test')
)

# Documents (x) and labels (y) of the training split.
train_x, train_y = datas.data, datas.target
# Documents (x) and labels (y) of the test split.
test_x, test_y = datas_test.data, datas_test.target
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
自动调参
import
time
def setParam(algo, name, cv=5):
    """Grid-search one hyper-parameter of ``algo`` and report time and error rates.

    The hyper-parameter to tune is chosen by probing the estimator's
    attributes: ``alpha``, ``max_depth`` and ``n_neighbors`` are checked in
    that order and, if several are present, the last match is used. The
    search is fitted on the module-level ``train_x``/``train_y`` and evaluated
    on ``test_x``/``test_y``.

    Args:
        algo: Estimator instance handed to ``GridSearchCV``.
        name: Label printed with the best estimator and echoed in the result.
        cv: Number of cross-validation folds (an int). Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        A tuple ``(name, avg_seconds, train_error, test_error)``.
        ``avg_seconds`` is the measured wall time divided by
        ``cv * number_of_candidates``; the timed span covers the grid-search
        fit and both predict calls. Both errors are ``1 - accuracy``.

    Raises:
        ValueError: If ``algo`` has none of the supported hyper-parameters.
    """
    # Candidate values for every hyper-parameter this helper knows how to tune.
    search_space = (
        ("alpha", np.logspace(-2, 9, 10)),
        ("max_depth", [2, 7, 10, 14, 20, 30]),
        ("n_neighbors", [2, 7, 10]),
    )
    param_grid = None
    m = 0  # number of candidate values in the selected grid
    for param, candidates in search_space:
        if hasattr(algo, param):
            param_grid = {param: candidates}
            m = len(candidates)
    if param_grid is None:
        # Fail early with a clear message instead of fitting an empty grid.
        raise ValueError(
            "{}: estimator exposes none of the tunable parameters "
            "alpha, max_depth, n_neighbors".format(name)
        )

    gridSearch = GridSearchCV(algo, param_grid=param_grid, cv=cv)

    # perf_counter is a monotonic clock meant for measuring intervals.
    t1 = time.perf_counter()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.perf_counter() - t1
    print(name, gridSearch.best_estimator_)

    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # The old expression `t2 / 5 * m` parsed as (t2 / 5) * m, which scaled the
    # time UP with the number of candidates; divide by the total fit count.
    return name, t2 / (cv * m), train_error, test_error
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
选择算法调参
朴素贝叶斯,随机森林,KNN
可视化
# Transpose the per-algorithm result tuples into four parallel lists:
# names, elapsed times, training error rates and test error rates.
names, times, train_err, test_err = [
    [row[col] for row in results] for col in range(4)
]

axes = plt.axes()
positions = np.arange(len(names))
# One bar series per metric, each shifted right by one bar width (0.1) so the
# three bars of an algorithm sit side by side.
for shift, heights, colour, caption in (
    (0, times, "red", "耗费时间"),
    (0.1, train_err, "green", "训练集错误"),
    (0.2, test_err, "blue", "测试集错误"),
):
    axes.bar(positions + shift, heights, color=colour, label=caption, width=0.1)
# Label each group with the algorithm name.
plt.xticks(positions, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
代码整合:
#coding=UTF-8
import
numpy
as
np
import
pandas
as
pd
defaultencoding
=
'utf-8'
import
matplotlib
as
mpl
import
matplotlib.
pyplot
as
plt
from
sklearn.
naive_bayes
import
MultinomialNB
from
sklearn.
neighbors
import
KNeighborsClassifier
from
sklearn.
linear_model
import
LogisticRegression
from
sklearn.
ensemble
import
RandomForestClassifier
from
sklearn.
feature_extraction.
text
import
TfidfVectorizer
from
sklearn.
model_selection
import
GridSearchCV
from
sklearn.
feature_selection
import
SelectKBest,
chi2
import
sklearn.
metrics
as
metrics
from
sklearn.
datasets
import
fetch_20newsgroups
import
sys
import
importlib,
sys
# Legacy Python 2 workaround: force the default string encoding to UTF-8.
# sys.setdefaultencoding() is not available on Python 3, so the call is
# guarded and simply skipped there instead of raising AttributeError.
if sys.getdefaultencoding() != defaultencoding:
    importlib.reload(sys)
    if hasattr(sys, "setdefaultencoding"):
        sys.setdefaultencoding(defaultencoding)

# Use the SimHei font so Chinese labels render in the plots.
mpl.rcParams['font.sans-serif'] = [u'simHei']
# Keep the minus sign displayable on the axes.
mpl.rcParams['axes.unicode_minus'] = False
# Load the same three newsgroups twice: once from the training split and once
# from the test split. data_home is the directory the dataset is stored in,
# subset picks the split and categories restricts which newsgroups are read.
datas, datas_test = (
    fetch_20newsgroups(
        data_home="./datas/",
        subset=split,
        categories=['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'],
    )
    for split in ('train', 'test')
)

# Documents (x) and labels (y) of the training split.
train_x, train_y = datas.data, datas.target
# Documents (x) and labels (y) of the test split.
test_x, test_y = datas_test.data, datas_test.target
# Convert the raw documents into TF-IDF vectors, ignoring English stop words.
# The vectorizer is fitted on the training documents and then reused, as is,
# to transform the test documents.
tfidf = TfidfVectorizer(stop_words="english")
train_x = tfidf.fit_transform(train_x, train_y)
test_x = tfidf.transform(test_x)
print(train_x.shape)

# Reduce the feature matrix to 1000 columns chosen with the chi-squared score.
# The selector is fitted on the training data and applied to the test data.
best = SelectKBest(chi2, k=1000)
train_x = best.fit_transform(train_x, train_y)
test_x = best.transform(test_x)
import
time
def setParam(algo, name, cv=5):
    """Grid-search one hyper-parameter of ``algo`` and report time and error rates.

    The hyper-parameter to tune is chosen by probing the estimator's
    attributes: ``alpha``, ``max_depth`` and ``n_neighbors`` are checked in
    that order and, if several are present, the last match is used. The
    search is fitted on the module-level ``train_x``/``train_y`` and evaluated
    on ``test_x``/``test_y``.

    Args:
        algo: Estimator instance handed to ``GridSearchCV``.
        name: Label printed with the best estimator and echoed in the result.
        cv: Number of cross-validation folds (an int). Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        A tuple ``(name, avg_seconds, train_error, test_error)``.
        ``avg_seconds`` is the measured wall time divided by
        ``cv * number_of_candidates``; the timed span covers the grid-search
        fit and both predict calls. Both errors are ``1 - accuracy``.

    Raises:
        ValueError: If ``algo`` has none of the supported hyper-parameters.
    """
    # Candidate values for every hyper-parameter this helper knows how to tune.
    search_space = (
        ("alpha", np.logspace(-2, 9, 10)),
        ("max_depth", [2, 7, 10, 14, 20, 30]),
        ("n_neighbors", [2, 7, 10]),
    )
    param_grid = None
    m = 0  # number of candidate values in the selected grid
    for param, candidates in search_space:
        if hasattr(algo, param):
            param_grid = {param: candidates}
            m = len(candidates)
    if param_grid is None:
        # Fail early with a clear message instead of fitting an empty grid.
        raise ValueError(
            "{}: estimator exposes none of the tunable parameters "
            "alpha, max_depth, n_neighbors".format(name)
        )

    gridSearch = GridSearchCV(algo, param_grid=param_grid, cv=cv)

    # perf_counter is a monotonic clock meant for measuring intervals.
    t1 = time.perf_counter()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.perf_counter() - t1
    print(name, gridSearch.best_estimator_)

    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # The old expression `t2 / 5 * m` parsed as (t2 / 5) * m, which scaled the
    # time UP with the number of candidates; divide by the total fit count.
    return name, t2 / (cv * m), train_error, test_error
plt.figure()
# (label, estimator) pairs to tune: naive Bayes, random forest and k-NN.
algorithm = [
    ("mnb", MultinomialNB()),
    ("random", RandomForestClassifier()),
    ("knn", KNeighborsClassifier()),
]
# Tune each estimator in turn and collect the tuples returned by setParam.
results = [setParam(algo, name) for name, algo in algorithm]
# Transpose the per-algorithm result tuples into four parallel lists:
# names, elapsed times, training error rates and test error rates.
names, times, train_err, test_err = [
    [row[col] for row in results] for col in range(4)
]

axes = plt.axes()
positions = np.arange(len(names))
# One bar series per metric, each shifted right by one bar width (0.1) so the
# three bars of an algorithm sit side by side.
for shift, heights, colour, caption in (
    (0, times, "red", "耗费时间"),
    (0.1, train_err, "green", "训练集错误"),
    (0.2, test_err, "blue", "测试集错误"),
):
    axes.bar(positions + shift, heights, color=colour, label=caption, width=0.1)
# Label each group with the algorithm name.
plt.xticks(positions, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
结果:


边栏推荐
- 使用哈工大LTP测试分词并且增加自定义字典
- 振弦传感器及核心VM系列振弦采集模块
- GPU accelerated Pinterest recommendation model, the number of parameters increased by 100 times, and the user activity increased by 16%
- POJ 2891 Strange Way to Express Integers (Extended Euclidean)
- 金九银十跳槽旺季:阿里、百度、京东、美团等技术面试题及答案
- Break through the dimensional barriers and let the dolls around you move on the screen!
- AutoCAD Map 3D功能之一暴力处理悬挂点(延伸)
- Codeforces 814 C. An impassioned circulation of affection (dp)
- ENVI 5.3软件安装包和安装教程
- 力扣练习——64 最长和谐子序列
猜你喜欢

【电商运营】你真的了解社交媒体营销(SMM)吗?

Research on motion capture system for indoor combined positioning technology

Spss-多元回归案例实操
今天面了个腾讯拿38K出来的大佬,让我见识到了基础的天花板

第2章-矩阵及其运算-矩阵运算(2)

英特尔推送20220809 CPU微码更新 修补Intel-SA-00657安全漏洞

一文带你搞懂中断按键驱动程序之poll机制

The brave rice rice, does not fear the brush list of 】 list has a ring
![[Brave food, not afraid of the linked list of brushing questions] Merging of ordered linked lists](/img/06/9d49fc99ab684f03740deb2abc38e2.png)
[Brave food, not afraid of the linked list of brushing questions] Merging of ordered linked lists

Article take you understand interrupt the key driver of polling mechanism
随机推荐
8月份DB-Engines 数据库排行榜最新战况
Some tips for using Unsafe
Short video software development - how to break the platform homogenization
使用.NET简单实现一个Redis的高性能克隆版(六)
推荐6个自媒体领域,轻松易上手
CPU多级缓存与缓存一致性
从源码角度分析UUID的实现原理
微信小程序提交审核历史版本记录从哪里查看
短视频软件开发——平台同质化如何破局
程序员追求技术夯实基础学习路线建议
力扣练习——58 验证二叉搜索树
【电商运营】你真的了解社交媒体营销(SMM)吗?
Hangdian Multi-School-Loop-(uncertainty greedy + line segment tree)
从产品角度看 L2 应用:为什么说这是一个游乐场?
GPU加速Pinterest推荐模型,参数量增加100倍,用户活跃度提高16%
三个绘图工具类详解Paint(画笔)Canvas(画布)Path(路径)
AUTOCAD——减少样条曲线控制点数、CAD进阶练习(三)
runtime-core.esm-bundler.js?d2dd:218 Uncaught TypeError: formRef.value?.validate is not a function
HDU 1520 Anniversary party (树型dp)
Will SQL and NoSQL eventually converge?