当前位置:网站首页>A case of violent parameter tuning in machine learning
A case of violent parameter tuning in machine learning
2022-08-10 11:33:00 【51CTO】
暴力调参案例
使用的数据集为
from sklearn.datasets import fetch_20newsgroups
因为在线下载慢,可以提前下载并保存到本地目录(即代码中 data_home = "./datas/" 指定的位置)
首先引入所需库
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
编码与中文显示设置
mpl.rcParams['font.sans-serif']=[u'simHei'] 用来正常显示中文;mpl.rcParams['axes.unicode_minus']=False 用来正常显示负号。如果报错的话,可以改为系统中已安装的其他中文字体。
获取数据
# data_home="./datas/": directory where the downloaded newsgroup data is cached;
# subset='train' fetches the training split; categories selects which news groups to load.
datas = fetch_20newsgroups( data_home = "./datas/", subset = 'train', categories =[ 'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
datas_test = fetch_20newsgroups( data_home = "./datas/", subset = 'test', categories =[ 'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
train_x = datas. data #raw news documents (X) of the training split
train_y = datas. target #class labels (y) of the training split
test_x = datas_test. data #raw news documents of the test split
test_y = datas_test. target #class labels of the test split
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
自动调参
import time


def setParam(algo, name, X_train=None, y_train=None, X_test=None, y_test=None):
    """Grid-search the single key hyper-parameter of *algo* and report results.

    The parameter grid is picked by which attribute the estimator exposes:
    ``alpha`` (naive Bayes), ``max_depth`` (random forest) or
    ``n_neighbors`` (KNN).

    Parameters
    ----------
    algo : scikit-learn classifier instance
    name : str
        Label used in the printout and in the returned tuple.
    X_train, y_train, X_test, y_test : optional
        Data to use; when omitted, the module-level globals
        ``train_x``/``train_y``/``test_x``/``test_y`` are used, so existing
        callers keep working unchanged.

    Returns
    -------
    tuple
        ``(name, scaled_time, train_error, test_error)``.
    """
    if X_train is None:
        # Backward-compatible default: fall back to the globals prepared above.
        X_train, y_train, X_test, y_test = train_x, train_y, test_x, test_y
    gridSearch = GridSearchCV(algo, param_grid=[], cv=5)
    m = 0  # number of candidate parameter values searched
    # The three estimators used here expose exactly one of these attributes,
    # so elif keeps the original behavior while making the dispatch explicit.
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = len(n)  # fix: was hard-coded 10; derive it from the grid itself
    elif hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    elif hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    t1 = time.time()
    gridSearch.fit(X_train, y_train)
    test_y_hat = gridSearch.predict(X_test)
    train_y_hat = gridSearch.predict(X_train)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(y_train, train_y_hat)
    test_error = 1 - metrics.accuracy_score(y_test, test_y_hat)
    # NOTE(review): scaling kept from the original: (elapsed / 5 folds) * m
    # candidates.  If a per-single-fit time was intended it should be
    # t2 / (5 * m) — confirm before relying on the absolute value.
    return name, t2 / 5 * m, train_error, test_error
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
选择算法调参
朴素贝叶斯,随机森林,KNN
可视化
# Unzip the result tuples: names, elapsed-time estimates, train error rates
# and test error rates each become one list (column i of every result row).
names, times, train_err, test_err =[[ x[ i] for x in results] for i in range( 0, 4)]
axes = plt. axes()
# Three grouped bars per algorithm; the Chinese labels mean: 耗费时间 = time
# spent, 训练集错误 = training-set error, 测试集错误 = test-set error.
axes. bar( np. arange( len( names)), times, color = "red", label = "耗费时间", width = 0.1)
axes. bar( np. arange( len( names)) + 0.1, train_err, color = "green", label = "训练集错误", width = 0.1)
axes. bar( np. arange( len( names)) + 0.2, test_err, color = "blue", label = "测试集错误", width = 0.1)
plt. xticks( np. arange( len( names)), names)
plt. legend()
plt. show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
代码整合:
#coding=UTF-8
import numpy as np
import pandas as pd
defaultencoding = 'utf-8'
import matplotlib as mpl
import matplotlib. pyplot as plt
from sklearn. naive_bayes import MultinomialNB
from sklearn. neighbors import KNeighborsClassifier
from sklearn. linear_model import LogisticRegression
from sklearn. ensemble import RandomForestClassifier
from sklearn. feature_extraction. text import TfidfVectorizer
from sklearn. model_selection import GridSearchCV
from sklearn. feature_selection import SelectKBest, chi2
import sklearn. metrics as metrics
from sklearn. datasets import fetch_20newsgroups
import sys
import importlib, sys

# Python 2 relic: force the interpreter's default encoding to UTF-8.  On
# Python 3, getdefaultencoding() always returns 'utf-8', so this branch is
# normally dead; sys.setdefaultencoding only exists on Python 2 after a
# reload, so guard it to make the call impossible to crash on Python 3.
if sys. getdefaultencoding() != defaultencoding:
    importlib. reload( sys)
    if hasattr(sys, "setdefaultencoding"):
        sys. setdefaultencoding( defaultencoding)
# Matplotlib: SimHei so Chinese labels render; keep the minus sign displayed
# correctly when a non-ASCII font is active.
mpl. rcParams[ 'font.sans-serif'] =[ u'simHei']
mpl. rcParams[ 'axes.unicode_minus'] = False
# data_home="./datas/": directory where the downloaded newsgroup data is cached;
# subset='train' fetches the training split; categories selects which news groups to load.
datas = fetch_20newsgroups( data_home = "./datas/", subset = 'train', categories =[ 'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
datas_test = fetch_20newsgroups( data_home = "./datas/", subset = 'test', categories =[ 'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc'])
train_x = datas. data #raw news documents (X) of the training split
train_y = datas. target #class labels (y) of the training split
test_x = datas_test. data #raw news documents of the test split
test_y = datas_test. target #class labels of the test split
tfidf = TfidfVectorizer( stop_words = "english")
train_x = tfidf. fit_transform( train_x, train_y) #fit TF-IDF on the training text and vectorize it
test_x = tfidf. transform( test_x) #vectorize the test text with the fitted vocabulary
print( train_x. shape)
best = SelectKBest( chi2, k = 1000) #chi2 feature selection: keep the 1000 best columns
train_x = best. fit_transform( train_x, train_y) #fit the selector and reduce the training matrix
test_x = best. transform( test_x) #apply the same selection to the test matrix
import time


def setParam(algo, name, X_train=None, y_train=None, X_test=None, y_test=None):
    """Grid-search the single key hyper-parameter of *algo* and report results.

    The parameter grid is picked by which attribute the estimator exposes:
    ``alpha`` (naive Bayes), ``max_depth`` (random forest) or
    ``n_neighbors`` (KNN).

    Parameters
    ----------
    algo : scikit-learn classifier instance
    name : str
        Label used in the printout and in the returned tuple.
    X_train, y_train, X_test, y_test : optional
        Data to use; when omitted, the module-level globals
        ``train_x``/``train_y``/``test_x``/``test_y`` are used, so existing
        callers keep working unchanged.

    Returns
    -------
    tuple
        ``(name, scaled_time, train_error, test_error)``.
    """
    if X_train is None:
        # Backward-compatible default: fall back to the globals prepared above.
        X_train, y_train, X_test, y_test = train_x, train_y, test_x, test_y
    gridSearch = GridSearchCV(algo, param_grid=[], cv=5)
    m = 0  # number of candidate parameter values searched
    # The three estimators used here expose exactly one of these attributes,
    # so elif keeps the original behavior while making the dispatch explicit.
    if hasattr(algo, "alpha"):
        n = np.logspace(-2, 9, 10)
        gridSearch.set_params(param_grid={"alpha": n})
        m = len(n)  # fix: was hard-coded 10; derive it from the grid itself
    elif hasattr(algo, "max_depth"):
        depth = [2, 7, 10, 14, 20, 30]
        gridSearch.set_params(param_grid={"max_depth": depth})
        m = len(depth)
    elif hasattr(algo, "n_neighbors"):
        neighbors = [2, 7, 10]
        gridSearch.set_params(param_grid={"n_neighbors": neighbors})
        m = len(neighbors)
    t1 = time.time()
    gridSearch.fit(X_train, y_train)
    test_y_hat = gridSearch.predict(X_test)
    train_y_hat = gridSearch.predict(X_train)
    t2 = time.time() - t1
    print(name, gridSearch.best_estimator_)
    train_error = 1 - metrics.accuracy_score(y_train, train_y_hat)
    test_error = 1 - metrics.accuracy_score(y_test, test_y_hat)
    # NOTE(review): scaling kept from the original: (elapsed / 5 folds) * m
    # candidates.  If a per-single-fit time was intended it should be
    # t2 / (5 * m) — confirm before relying on the absolute value.
    return name, t2 / 5 * m, train_error, test_error
# Candidate classifiers — naive Bayes, random forest and KNN — each tagged
# with the short name that appears on the chart below.
plt.figure()
algorithm = [
    ("mnb", MultinomialNB()),
    ("random", RandomForestClassifier()),
    ("knn", KNeighborsClassifier()),
]
results = []
for label, estimator in algorithm:
    results.append(setParam(estimator, label))
# Unzip the result tuples: names, elapsed-time estimates, train error rates
# and test error rates each become one list (column i of every result row).
names, times, train_err, test_err =[[ x[ i] for x in results] for i in range( 0, 4)]
axes = plt. axes()
# Three grouped bars per algorithm; the Chinese labels mean: 耗费时间 = time
# spent, 训练集错误 = training-set error, 测试集错误 = test-set error.
axes. bar( np. arange( len( names)), times, color = "red", label = "耗费时间", width = 0.1)
axes. bar( np. arange( len( names)) + 0.1, train_err, color = "green", label = "训练集错误", width = 0.1)
axes. bar( np. arange( len( names)) + 0.2, test_err, color = "blue", label = "测试集错误", width = 0.1)
plt. xticks( np. arange( len( names)), names)
plt. legend()
plt. show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
结果:
边栏推荐
- mysql appears: ERROR 1524 (HY000): Plugin '123' is not loaded
- 基于UiAutomator2+PageObject模式开展APP自动化测试实战
- 开发模式对测试的影响
- 谷歌数据中心发生“电力事故”造成 3 人受伤
- How can an organization judge the success of data governance?
- 电脑怎么设置屏幕息屏时间(日常使用分享)
- L2 applications from a product perspective: why is it a playground?
- 快速上手,征服三种不同分布式架构调用方案
- 是什么影响了MySQL性能?
- What is affecting MySQL performance?
猜你喜欢
随机推荐
How can an organization judge the success of data governance?
企业如何判断数据治理是否成功?
POJ 2891 Strange Way to Express Integers (Extended Euclidean)
Several small projects that I have open sourced over the years
关于“码农”的一点自嘲解构
mysql5.7 installation and deployment - yum installation
POJ 2891 Strange Way to Express Integers (扩展欧几里得)
L2 applications from a product perspective: why is it a playground?
使用哈工大LTP测试分词并且增加自定义字典
第二十二章 源代码文件 REST API 参考(四)
JWT implements login authentication + Token automatic renewal scheme
力扣练习—— 矩形区域不超过 K 的最大数值和(hard)
Pycharm终端出现PS问题、conda或activate不是内部命令问题..
[Brave food, not afraid of the linked list of brushing questions] Merging of ordered linked lists
The impact of development mode on testing
做自媒体月入几万?博主们都在用的几个自媒体工具
StoneDB 文档捉虫活动第一季
2023版揽胜运动曝光,安全、舒适一个不落
Weilai-software development engineer side record
如何使用工程仪器设备在线监测管理系统