当前位置:网站首页>《PyTorch深度学习实践》08 加载数据集
《PyTorch深度学习实践》08 加载数据集
2022-04-22 19:44:00 【小白学知识】
1. 说明
本系列博客记录B站课程《PyTorch深度学习实践》的实践代码课程链接请点我
2. 代码
# ---------------------------
# @Time : 2022/4/14 11:45
# @Author : lcq
# @File : 08_MyDataLoaderTest.py
# @Function :
# ---------------------------
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
# Step1: 加载数据
filePath = "data//diabetes.csv"
xy = np.loadtxt(filePath, delimiter=',', dtype=np.float32)
X = xy[:, :-1]
Y = xy[:, [-1]]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3) # 划分测试集和训练集,test_size=0.3为测试集的比例
X_test = torch.from_numpy(X_test) # 转换程torch的张量
Y_test = torch.from_numpy(Y_test)
# Step2:将训练集进行小批量处理
class MiniBatchDataset(Dataset):
def __init__(self, X_input, Y_label):
self.len = X_input.shape[0]
self.X_input = torch.from_numpy(X_input)
self.Y_label = torch.from_numpy(Y_label)
def __getitem__(self, index):
return self.X_input[index], self.Y_label[index]
def __len__(self):
return self.len
# Step3: 将数据变成小批量数据
train_data = MiniBatchDataset(X_train, Y_train)
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=2)
# Step4: 构建模型
class Model(torch.nn.Module):
def __init__(self, column):
super(Model, self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Linear(column, 6),
torch.nn.Sigmoid(),
torch.nn.Linear(6, 4),
torch.nn.Sigmoid(),
torch.nn.Linear(4, 2),
torch.nn.Sigmoid(),
torch.nn.Linear(2, 1),
torch.nn.Sigmoid()
)
def forward(self, X_input):
y_pred = self.model(X_input)
return y_pred
model = Model(column=X_test.shape[1])
# Step5: 构建损失函数和优化器
Loss = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Step6: 训练
def train(epoch):
for i, data in enumerate(train_loader, 0):
inputs, labels = data
y_pred = model.forward(inputs)
loss = Loss(y_pred, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("epoch: ", epoch, "i: ", i, "loss: ", loss.item())
def test():
with torch.no_grad(): # 表示之后的运算不需要进行图构建
y_pred = model.forward(X_test)
y_pred_label = torch.where(y_pred >= 0.5, torch.tensor([1.0]), torch.tensor([0.0]))
acc = torch.eq(y_pred_label, Y_test).sum().item() / Y_test.size(0)
print("test acc: ", acc)
if __name__ == '__main__':
for epoch in range(5):
train(epoch)
print('-------test-----------')
test()
版权声明
本文为[小白学知识]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_41915623/article/details/124292506
边栏推荐
猜你喜欢

More than 100 days, 0 basic self-study and career change software testing, from 3000 to 15K monthly salary, I compiled a super complete learning guide

从功能测试到自动化测试,待遇翻倍,我整理了这一份3000字超全学习指南

Shenkaihong signed a cooperation agreement with Yisheng technology to jointly build a new ecosystem of business display industry

【AI视野·今日Robot 机器人论文速览 第三十三期】Thu, 21 Apr 2022

.net 用supersocket搭建socket server

【八股文】Redis缓存

LeetCode_ 343 integer split

if-else 优化

10-Streaming Query

mmocr DBLoss
随机推荐
Selenium自动化之弹窗处理
2路CAN/CAN FD 数据记录诊断仪为企业解决偶发性错误难点
梅宏院士:如何构造人工群体智能
2-way can / can FD data recording diagnostic instrument solves the difficulty of accidental errors for enterprises
js复制粘贴,clipboard.js
嵌入式Web项目(一)——Web服务器的引入
.net socket.io客户端使用过程
骗子用AI语音获利近1.8亿,受害者:听不出来是机器人啊
Arithmetic overflow error converting identity to data type int
2022-01-12 微信小程序调试
[Niuke brush question 19] MP3 cursor position
sqlserver中一个表中树形结构递归数据查询
STM32学习记录006——新建工程模板(基于固件库)
Royalscope quickly locates the fault node in the can network and arranges the quality of the CAN bus
Sqlserver determines whether a column in the table contains Chinese, English and pure numbers
Shenkaihong signed a cooperation agreement with Yisheng technology to jointly build a new ecosystem of business display industry
.net 后台上传图片不用保存图片实现压缩图片
13-Set Time Zone
. net using supersocket to build socket server
闭包的概念、作用、问题及解决方式