pytorch-08. Load dataset
2022-08-10 05:55:00 【Shengxin Research Ape】
```python
import numpy as np
import torch
from torch.utils.data import Dataset     # Dataset is an abstract class: it cannot be instantiated, only subclassed
from torch.utils.data import DataLoader  # DataLoader can be instantiated directly


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples (rows)
        self.x_data = torch.from_numpy(xy[:, :-1])   # every column except the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D as an (N, 1) label matrix

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):  # return the length of the dataset
        return self.len


dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)  # number of worker subprocesses that read data in parallel


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        # The output layer uses Sigmoid, not ReLU: ReLU returns 0 for negative inputs,
        # which would put ln(0) into the BCE loss. Sigmoid keeps outputs in (0, 1).
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
criterion = torch.nn.BCELoss(reduction='mean')  # replaces the deprecated size_average=True
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# enumerate() returns the index and the value at the same time, so use it
# whenever the loop needs both (here: the batch number i and the batch itself).

if __name__ == '__main__':  # needed when num_workers > 0 on platforms that spawn workers (e.g. Windows)
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):  # 0: start counting from index 0
            # 1. Prepare data
            inputs, labels = data
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()
```
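To see what DataLoader actually does with a Dataset subclass, here is a minimal sketch that can run without `diabetes.csv.gz`. It assumes a synthetic stand-in (100 random rows, 8 feature columns plus a 0/1 label column) and a hypothetical `ArrayDataset` that takes an array instead of a file path; otherwise it mirrors `DiabetesDataset`. DataLoader uses `__len__` and `__getitem__` to draw shuffled indices, and its default collate function stacks the individual samples into batch tensors. The last line also shows why the labels are sliced with `[-1]` inside a list: it keeps an `(N, 1)` shape that matches the model's output.

```python
import math

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayDataset(Dataset):
    """Hypothetical variant of DiabetesDataset built from an in-memory array."""

    def __init__(self, xy: np.ndarray):
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # features: all but the last column
        self.y_data = torch.from_numpy(xy[:, [-1]])  # labels: last column, shape (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


# Synthetic stand-in for diabetes.csv.gz (assumption): 8 features + 1 binary label.
rng = np.random.default_rng(0)
features = rng.standard_normal((100, 8)).astype(np.float32)
labels = rng.integers(0, 2, size=(100, 1)).astype(np.float32)
xy = np.hstack([features, labels])

dataset = ArrayDataset(xy)
# num_workers=0 loads in the main process, so no __main__ guard is needed here.
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# len(loader) is the number of batches per epoch: ceil(100 / 32).
print(len(loader), math.ceil(len(dataset) / 32))

batch_shapes = [(x.shape, y.shape) for x, y in loader]
for i, (x_shape, y_shape) in enumerate(batch_shapes):
    print(i, tuple(x_shape), tuple(y_shape))

# Indexing with -1 drops the column axis; indexing with [-1] keeps it.
print(xy[:, -1].shape, xy[:, [-1]].shape)
```

With 100 rows and `batch_size=32`, the loader yields three batches of 32 and a final batch of 4, because `drop_last` defaults to `False`.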