当前位置:网站首页>机器学习之暴力调参案例
机器学习之暴力调参案例
2022-08-10 10:56:00 【51CTO】
暴力调参案例
使用的数据集为
from sklearn.datasets import fetch_20newsgroups
因为在线下载慢,可以提前下载数据集并保存到 data_home 指定的目录(本文代码中为 "./datas/")
首先引入所需库
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
编码问题显示
如果报错的话可以改为
# 用来正常显示中文
mpl.rcParams['font.sans-serif'] = [u'simHei']
# 用来正常显示正负号
mpl.rcParams['axes.unicode_minus'] = False
获取数据
# Load three newsgroups from the 20-newsgroups corpus.
# data_home is the directory where the downloaded posts are stored, subset
# selects the train or test split, and categories names the groups to load.
news_groups = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
datas = fetch_20newsgroups(data_home="./datas/", subset='train', categories=news_groups)
datas_test = fetch_20newsgroups(data_home="./datas/", subset='test', categories=news_groups)
# X is the raw post text, Y the target value of each post.
train_x, train_y = datas.data, datas.target
test_x, test_y = datas_test.data, datas_test.target
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
自动调参
import time
def setParam(algo, name):
    """Grid-search one hyper-parameter of ``algo`` and report time and error rates.

    The tuned parameter is chosen by probing the estimator's attributes in the
    order ``alpha``, ``max_depth``, ``n_neighbors``; if several are present the
    last match replaces the earlier grid.

    Reads the module-level ``train_x``, ``train_y``, ``test_x`` and ``test_y``.

    Args:
        algo: Estimator instance handed to ``GridSearchCV``.
        name: Label of the estimator; printed and returned unchanged.

    Returns:
        Tuple ``(name, avg_time, train_error, test_error)``. ``avg_time`` is
        the wall-clock time of the search plus both predictions, divided by
        the number of cross-validation fits (folds x candidate values). The
        error values are ``1 - accuracy`` on the training and test data.
    """
    folds = 5
    gridSearch = GridSearchCV(algo, param_grid=[], cv=folds)
    m = 0  # number of candidate values in the selected grid
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = len(n)
    if hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    if hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)

    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)

    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)

    # Divide by the total number of fits. Writing this as `t2 / folds * m`
    # would multiply by the candidate count because / and * associate left
    # to right.
    n_fits = folds * m
    # No grid was selected (m == 0): report the raw elapsed time rather than
    # dividing by zero.
    avg_time = t2 / n_fits if n_fits else t2
    return name, avg_time, train_error, test_error
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
选择算法调参
朴素贝叶斯,随机森林,KNN
可视化
# Split the result tuples column-wise into separate lists: names, elapsed
# times, training error rates and test error rates.
names, times, train_err, test_err = [[row[col] for row in results] for col in range(4)]
axes = plt.axes()
positions = np.arange(len(names))
# One bar series per metric, each shifted right by one bar width (0.1).
bar_series = (
    (0, times, "red", "耗费时间"),
    (0.1, train_err, "green", "训练集错误"),
    (0.2, test_err, "blue", "测试集错误"),
)
for shift, heights, colour, caption in bar_series:
    axes.bar(positions + shift, heights, color=colour, label=caption, width=0.1)
plt.xticks(positions, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
代码整合:
#coding=UTF-8
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
import importlib, sys
# Legacy workaround that tries to force the interpreter's default string
# encoding to the value of `defaultencoding`.
if sys.getdefaultencoding() != defaultencoding:
    importlib.reload(sys)
    # sys.setdefaultencoding does not exist on Python 3, so look it up
    # defensively instead of risking an AttributeError in this branch.
    set_default_encoding = getattr(sys, "setdefaultencoding", None)
    if set_default_encoding is not None:
        set_default_encoding(defaultencoding)
# Font setting so Chinese text is displayed correctly in figures.
mpl.rcParams['font.sans-serif'] = [u'simHei']
# Setting so the minus sign is displayed correctly.
mpl.rcParams['axes.unicode_minus'] = False
# Load three newsgroups from the 20-newsgroups corpus.
# data_home is the directory where the downloaded posts are stored, subset
# selects the train or test split, and categories names the groups to load.
news_groups = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
datas = fetch_20newsgroups(data_home="./datas/", subset='train', categories=news_groups)
datas_test = fetch_20newsgroups(data_home="./datas/", subset='test', categories=news_groups)
# X is the raw post text, Y the target value of each post.
train_x, train_y = datas.data, datas.target
test_x, test_y = datas_test.data, datas_test.target

# Turn the text into TF-IDF vectors: fit on the training posts, then apply
# the fitted vectorizer to the test posts.
tfidf = TfidfVectorizer(stop_words="english")
train_x = tfidf.fit_transform(train_x, train_y)
test_x = tfidf.transform(test_x)
print(train_x.shape)

# Reduce the feature matrix to 1000 columns using the chi2 score; the
# selector is fitted on the training data and reused for the test data.
best = SelectKBest(chi2, k=1000)
train_x = best.fit_transform(train_x, train_y)
test_x = best.transform(test_x)
import time
def setParam(algo, name):
    """Grid-search one hyper-parameter of ``algo`` and report time and error rates.

    The tuned parameter is chosen by probing the estimator's attributes in the
    order ``alpha``, ``max_depth``, ``n_neighbors``; if several are present the
    last match replaces the earlier grid.

    Reads the module-level ``train_x``, ``train_y``, ``test_x`` and ``test_y``.

    Args:
        algo: Estimator instance handed to ``GridSearchCV``.
        name: Label of the estimator; printed and returned unchanged.

    Returns:
        Tuple ``(name, avg_time, train_error, test_error)``. ``avg_time`` is
        the wall-clock time of the search plus both predictions, divided by
        the number of cross-validation fits (folds x candidate values). The
        error values are ``1 - accuracy`` on the training and test data.
    """
    folds = 5
    gridSearch = GridSearchCV(algo, param_grid=[], cv=folds)
    m = 0  # number of candidate values in the selected grid
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = len(n)
    if hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    if hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)

    t1 = time.time()
    gridSearch.fit(train_x, train_y)
    test_y_hat = gridSearch.predict(test_x)
    train_y_hat = gridSearch.predict(train_x)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)

    train_error = 1 - metrics.accuracy_score(train_y, train_y_hat)
    test_error = 1 - metrics.accuracy_score(test_y, test_y_hat)

    # Divide by the total number of fits. Writing this as `t2 / folds * m`
    # would multiply by the candidate count because / and * associate left
    # to right.
    n_fits = folds * m
    # No grid was selected (m == 0): report the raw elapsed time rather than
    # dividing by zero.
    avg_time = t2 / n_fits if n_fits else t2
    return name, avg_time, train_error, test_error
plt.figure()
algorithm = [
    ("mnb", MultinomialNB()),
    ("random", RandomForestClassifier()),
    ("knn", KNeighborsClassifier()),
]
# Tune every estimator and collect the tuple returned by setParam for each.
results = []
for name, algo in algorithm:
    results.append(setParam(algo, name))

# Split the result tuples column-wise into separate lists: names, elapsed
# times, training error rates and test error rates.
names, times, train_err, test_err = [[row[col] for row in results] for col in range(4)]
axes = plt.axes()
positions = np.arange(len(names))
# One bar series per metric, each shifted right by one bar width (0.1).
bar_series = (
    (0, times, "red", "耗费时间"),
    (0.1, train_err, "green", "训练集错误"),
    (0.2, test_err, "blue", "测试集错误"),
)
for shift, heights, colour, caption in bar_series:
    axes.bar(positions + shift, heights, color=colour, label=caption, width=0.1)
plt.xticks(positions, names)
plt.legend()
plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
结果:
边栏推荐
猜你喜欢
随机推荐
【勇敢饭饭,不怕刷题之链表】链表反转的几种情况
动作捕捉系统用于室内组合定位技术研究
what is bsp in rtems
1-IMU参数解析以及选择
mysql出现:ERROR 1524 (HY000): Plugin ‘123‘ is not loaded
第二十二章 源代码文件 REST API 参考(四)
Some tips for using Unsafe
2022年裁员潮,失业程序员何去何从?
蔚来-软件开发工程师一面记录
Several small projects that I have open sourced over the years
让软件飞——“X+”技术揭秘
接口定义与实现
What is an abstract class
【无标题】
从产品角度看 L2 应用:为什么说这是一个游乐场?
使用cpolar远程连接群晖NAS(升级固定链接2)
为什么Redis很快
学长告诉我,大厂MySQL都是通过SSH连接的
LeetCode_152_乘积最大子数组
Weilai-software development engineer side record