PyTorch - 07. Handling multi-dimensional feature input
2022-08-10 05:56:00 【Shengxin Research Ape】
import numpy as np
import torch

# Load the dataset; delimiter=',' means the fields are separated by commas.
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] keeps a matrix, not a vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
criterion = torch.nn.BCELoss(reduction='mean')  # replaces the deprecated size_average=True
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test the model
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
y_test = model(x_test)
print('y_pred=', y_test.item())

The final block of the script tests the trained model on one set of data.
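The two slicing expressions used when loading the data return arrays of different shapes, which is easy to miss. Here is a minimal sketch on a small made-up array (the values are illustrative and not taken from diabetes.csv.gz; `y_mat` and `y_vec` are names introduced only for this example):

```python
import numpy as np

# Hypothetical 3-sample array: 2 feature columns plus 1 label column.
xy = np.array([[0.1, 0.2, 1.0],
               [0.3, 0.4, 0.0],
               [0.5, 0.6, 1.0]], dtype=np.float32)

x = xy[:, :-1]       # every column except the last -> shape (3, 2)
y_mat = xy[:, [-1]]  # a list index keeps the column axis -> shape (3, 1)
y_vec = xy[:, -1]    # a plain integer index drops it -> shape (3,)

print(x.shape, y_mat.shape, y_vec.shape)  # (3, 2) (3, 1) (3,)
```

The matrix form matters because the last layer, Linear(4, 1), produces predictions of shape (N, 1), and BCELoss expects the target to have the same shape as the prediction.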

Change the activation function to ReLU:
import numpy as np
import torch

# Load the dataset; delimiter=',' means the fields are separated by commas.
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] keeps a matrix, not a vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        # Keep sigmoid on the last layer: ReLU outputs 0 whenever its input is
        # negative, so ln(0) could appear when the loss is calculated.
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
criterion = torch.nn.BCELoss(reduction='mean')  # replaces the deprecated size_average=True
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test the model
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
y_test = model(x_test)
print('y_pred=', y_test.item())
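The comment about ln(0) can be checked by hand. The sketch below uses only the Python standard library and writes out the single-sample binary cross-entropy formula directly; it is not PyTorch's implementation, and the pre-activation `z` and label `y` are hypothetical values chosen for illustration:

```python
import math


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce(y_hat, y):
    # Binary cross-entropy for one sample:
    # -(y * ln(y_hat) + (1 - y) * ln(1 - y_hat))
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))


z = -2.0  # hypothetical output of the last linear layer
y = 1.0   # hypothetical label

p_sigmoid = sigmoid(z)  # strictly between 0 and 1, so the log is defined
print("sigmoid output:", p_sigmoid, "loss:", bce(p_sigmoid, y))

p_relu = relu(z)        # exactly 0.0 for any negative z
try:
    bce(p_relu, y)
except ValueError as err:
    # math.log(0.0) is undefined and raises ValueError
    print("ReLU output:", p_relu, "loss undefined:", err)
```

As far as the PyTorch documentation describes it, BCELoss clamps its log terms instead of raising an error, so a zero output would give a very large loss rather than a crash. A ReLU output can also exceed 1, which is not a valid probability, so sigmoid remains the safer choice for the final layer.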








