当前位置:网站首页>pytorch-08.加载数据集
pytorch-08.加载数据集
2022-08-10 05:32:00 【生信研究猿】
import torch
from torch.utils.data import Dataset # Dataset抽象类,不可实例化,只能继承
from torch.utils.data import DataLoader # DataLoader 可实例化
import numpy as np
class DiabetesDataset(Dataset):
def __init__(self,filepath):
xy = np.loadtxt(filepath,delimiter=',',dtype = np.float32)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:,:-1])
self.y_data = torch.from_numpy(xy[:,[-1]])
def __getitem__(self, index):
return self.x_data[index],self.y_data[index]
def __len__(self): #返回数据集长度
return self.len
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2) #num_workers 并行读数据
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
self.activate = torch.nn.ReLU()
def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.sigmoid(self.linear3(x)) #RELU,x小于0时的的y值都是0,算损失时有可能出现ln0
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
'''
enumerate多用于在for循环中得到计数,利用它可以同时获得索引和值,即需要index和value值的时候可以使用enumerate
'''
if __name__ == '__main__':
for epoch in range(100):
for i,data in enumerate(train_loader,0): #0代表从指定索引0开始
print('train_loader:',train_loader)
# 1 Prepare data
inputs,labels = data
# 2 Forward
y_pred = model(inputs)
loss = criterion(y_pred,labels)
print(epoch,i,loss.item())
# 3 Backward
optimizer.zero_grad()
loss.backward()
#4 Update
optimizer.step()
边栏推荐
- 第二次实验
- 小程序wx.request简单Promise封装
- Chain Reading Recommendation: From Tiles to Generative NFTs
- Index Notes【】【】
- Bifrost micro synchronous database implementation services across the library data synchronization
- 复杂的“元宇宙”,为您解读,链读APP即将上线!
- Analysis of the investment value of domestic digital collections
- 我不喜欢我的代码
- [Notes] Collection Framework System Collection
- 各个架构指令集对应的机型
猜你喜欢
随机推荐
Bifrost micro synchronous database implementation services across the library data synchronization
常用类 BigDecimal
Linux database Oracle client installation, used for shell scripts to connect to the database with sqlplus
el-dropdown下拉菜单样式修改,去掉小三角
链读好文:Jeff Garzik 推出 Web3 制作公司
Count down the six weapons of the domestic interface collaboration platform!
Decentralized and p2p networks and traditional communications with centralization at the core
Smart contracts and DAPP decentralized applications
你不知道的常规流
智能合约和去中心化应用DAPP
栈和队列
IO流【】【】【】
细说MySql索引原理
知识蒸馏论文学习
Copy large files with crontab
力扣——统计只差一个字符的子串数目
并查集原理与API设计
21天挑战杯MySQL——Day06
集合 set接口
2021-07-09