当前位置:网站首页>A case of violent parameter tuning in machine learning
A case of violent parameter tuning in machine learning
2022-08-10 11:33:00 【51CTO】
暴力调参案例
使用的数据集为 sklearn 自带的 20newsgroups 新闻数据集:
from sklearn.datasets import fetch_20newsgroups
因为在线下载慢,可以提前下载并保存到本地目录 ./datas/ 下。
首先引入所需库
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
编码问题显示
如果报错的话可以改为
mpl.rcParams['font.sans-serif']=[u'simHei'] 用来正常显示中文;mpl.rcParams['axes.unicode_minus']=False 用来正常显示负号。
获取数据
#data_home="./datas/"下载的新闻的保存地址subset='train'表示从训练集获取新闻categories获取哪些种类的新闻
# Fetch three newsgroup categories; data_home caches the corpus under
# ./datas/ so repeated runs skip the slow download.
_cats = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
datas = fetch_20newsgroups(data_home="./datas/", subset='train', categories=_cats)
datas_test = fetch_20newsgroups(data_home="./datas/", subset='test', categories=_cats)
train_x, train_y = datas.data, datas.target          # training documents and labels
test_x, test_y = datas_test.data, datas_test.target  # test documents and labels
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
自动调参
import time
def setParam(algo, name):
    """Grid-search the main hyper-parameter of ``algo`` and report errors.

    Parameters
    ----------
    algo : sklearn estimator
        An *unfitted* classifier (MultinomialNB, RandomForestClassifier,
        KNeighborsClassifier, ...). The hyper-parameter grid is picked by
        inspecting which attribute the estimator exposes.
    name : str
        Display name; returned unchanged.

    Returns
    -------
    tuple
        ``(name, scaled_time, train_error, test_error)``.

    Notes
    -----
    Reads the module-level ``train_x``/``train_y``/``test_x``/``test_y``
    globals prepared earlier in the script.
    """
    # param_grid={} means "one candidate: the defaults", so fit() still
    # works for an estimator none of the branches below recognise.
    # (The original param_grid=[] yields zero candidates and crashes.)
    gridSearch = GridSearchCV(algo, param_grid={}, cv=5)
    m = 0
    # if/elif: make the branches mutually exclusive so a later match can
    # never silently overwrite a grid chosen by an earlier one.
    if hasattr(algo, "alpha"):
        alphas = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": alphas})
        m = len(alphas)  # was hard-coded 10; derive it from the grid
    elif hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    elif hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # NOTE(review): t2/5*m scales total time by candidates/folds; the intent
    # (per-fit estimate would be t2/5/m) is unclear — kept as originally written.
    return name, t2 / 5 * m, train_error, test_error
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
选择算法调参
朴素贝叶斯,随机森林,KNN
可视化
#把名称,花费时间,训练错误率,测试错误率分别存到单个数组
# Transpose the (name, time, train_err, test_err) tuples in `results`
# into four parallel per-metric lists.
names, times, train_err, test_err = [list(col) for col in zip(*results)]
x = np.arange(len(names))
axes = plt.axes()
# Three grouped bars per model: elapsed time, train error, test error.
for shift, series, colour, label in (
        (0.0, times, "red", "耗费时间"),
        (0.1, train_err, "green", "训练集错误"),
        (0.2, test_err, "blue", "测试集错误")):
    axes.bar(x + shift, series, color=colour, label=label, width=0.1)
plt.xticks(x, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
代码整合:
#coding=UTF-8
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
import importlib, sys
# Python 2 legacy: sys.setdefaultencoding() was removed in Python 3 (where
# the default is already UTF-8, so the guard below is normally False).
# Check hasattr so the reload hack can never raise AttributeError on Py3.
if sys.getdefaultencoding() != defaultencoding and hasattr(sys, "setdefaultencoding"):
    importlib.reload(sys)
    sys.setdefaultencoding(defaultencoding)
# Render CJK text and the minus sign correctly in matplotlib figures.
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
#data_home="./datas/"下载的新闻的保存地址subset='train'表示从训练集获取新闻categories获取哪些种类的新闻
# Fetch three newsgroup categories; data_home caches the corpus under
# ./datas/ so repeated runs skip the slow download.
_categories = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
datas = fetch_20newsgroups(data_home="./datas/", subset='train', categories=_categories)
datas_test = fetch_20newsgroups(data_home="./datas/", subset='test', categories=_categories)
train_x, train_y = datas.data, datas.target          # training documents and labels
test_x, test_y = datas_test.data, datas_test.target  # test documents and labels
# TF-IDF vectorisation with English stop words removed; fitted on the
# training split only, then applied to the test split.
tfidf = TfidfVectorizer(stop_words="english")
train_x = tfidf.fit_transform(train_x, train_y)
test_x = tfidf.transform(test_x)
print(train_x.shape)
# Chi-squared feature selection keeps the 1000 most label-relevant columns.
best = SelectKBest(chi2, k=1000)
train_x = best.fit_transform(train_x, train_y)
test_x = best.transform(test_x)
import time
def setParam(algo, name):
    """Grid-search the main hyper-parameter of ``algo`` and report errors.

    Parameters
    ----------
    algo : sklearn estimator
        An *unfitted* classifier (MultinomialNB, RandomForestClassifier,
        KNeighborsClassifier, ...). The hyper-parameter grid is picked by
        inspecting which attribute the estimator exposes.
    name : str
        Display name; returned unchanged.

    Returns
    -------
    tuple
        ``(name, scaled_time, train_error, test_error)``.

    Notes
    -----
    Reads the module-level ``train_x``/``train_y``/``test_x``/``test_y``
    globals prepared earlier in the script.
    """
    # param_grid={} means "one candidate: the defaults", so fit() still
    # works for an estimator none of the branches below recognise.
    # (The original param_grid=[] yields zero candidates and crashes.)
    gridSearch = GridSearchCV(algo, param_grid={}, cv=5)
    m = 0
    # if/elif: make the branches mutually exclusive so a later match can
    # never silently overwrite a grid chosen by an earlier one.
    if hasattr(algo, "alpha"):
        alphas = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": alphas})
        m = len(alphas)  # was hard-coded 10; derive it from the grid
    elif hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    elif hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)
    # NOTE(review): t2/5*m scales total time by candidates/folds; the intent
    # (per-fit estimate would be t2/5/m) is unclear — kept as originally written.
    return name, t2 / 5 * m, train_error, test_error
# Tune each candidate model and collect (name, time, train_err, test_err).
results = []
plt.figure()
algorithm = [("mnb", MultinomialNB()),
             ("random", RandomForestClassifier()),
             ("knn", KNeighborsClassifier())]
for name, algo in algorithm:
    results.append(setParam(algo, name))
# Transpose the result tuples into four parallel per-metric lists.
names, times, train_err, test_err = [list(col) for col in zip(*results)]
x = np.arange(len(names))
axes = plt.axes()
# Three grouped bars per model: elapsed time, train error, test error.
for shift, series, colour, label in (
        (0.0, times, "red", "耗费时间"),
        (0.1, train_err, "green", "训练集错误"),
        (0.2, test_err, "blue", "测试集错误")):
    axes.bar(x + shift, series, color=colour, label=label, width=0.1)
plt.xticks(x, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
结果:
边栏推荐
- 是什么影响了MySQL性能?
- 负载均衡原理分析与源码解读
- 【TypeScript】接口类型与类型别名:这两者的用法与区别分别是什么?
- 3 injured in 'electrical accident' at Google data center
- Cybersecurity Notes 5 - Digital Signatures
- 2022年裁员潮,失业程序员何去何从?
- Codeforces 814 C. An impassioned circulation of affection (dp)
- Nocalhost - 让云原生时代的开发更高效
- HCIP ---- VLAN
- 快速上手,征服三种不同分布式架构调用方案
猜你喜欢
随机推荐
POJ 1026 Cipher (Permutation Groups)
【勇敢饭饭,不怕刷题之链表】链表反转的几种情况
3 injured in 'electrical accident' at Google data center
建校仅11年就入选“双一流” ,这所高校是凭什么做到的?
关于振弦采集模块及采集仪振弦频率值准确率的问题
[Brave food, not afraid to write the linked list] The problem of the penultimate node of the linked list
GPU accelerated Pinterest recommendation model, the number of parameters increased by 100 times, and the user activity increased by 16%
LAXCUS分布式操作系统安全管理
Emulate stm32 directly with proteus - the programmer can be completely discarded
SQL优化最强总结 (建议收藏~)
Mobile and PC compatible loading and toast message plugins
为什么Redis很快
英特尔推送20220809 CPU微码更新 修补Intel-SA-00657安全漏洞
快速上手,征服三种不同分布式架构调用方案
力扣练习——64 最长和谐子序列
做自媒体月入几万?博主们都在用的几个自媒体工具
使用JMeter进行MySQL的压力测试
Three-phase 380V rectified voltage
Double.doubleToLongBits()方法使用
OneFlow source code parsing: operator instructions executed in a virtual machine