PyTorch 07: Handling Multi-Dimensional Feature Input
2022-08-10 05:56:00 【Shengxin Research Ape】
```python
import numpy as np
import torch

# Load the dataset (np.loadtxt decompresses .gz files itself); delimiter=',' means columns are comma-separated
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] inside [] keeps an (N, 1) matrix, not an (N,) vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)   # 8 input features -> 6
        self.linear2 = torch.nn.Linear(6, 4)   # 6 -> 4
        self.linear3 = torch.nn.Linear(4, 1)   # 4 -> 1 output probability
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
# size_average=True is deprecated; reduction='mean' is the current equivalent
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    # Forward pass over the whole dataset
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())
    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test model
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
with torch.no_grad():  # no gradients needed for inference
    y_test = model(x_test)
print('y_pred=', y_test.item())
```
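The two slicing expressions in the loading step are easy to confuse. The sketch below uses a small hypothetical array in place of `diabetes.csv.gz` (assumed, to match `Linear(8, 6)`, to hold 8 feature columns plus 1 label column) to show the shapes each slice produces and how the three layers shrink the feature dimension.

```python
import numpy as np
import torch

# Hypothetical stand-in for diabetes.csv.gz: 5 samples, 8 features + 1 label column
xy = np.arange(5 * 9, dtype=np.float32).reshape(5, 9)

x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last -> shape (5, 8)
y_vec = torch.from_numpy(xy[:, -1])     # integer index drops a dimension -> shape (5,)
y_data = torch.from_numpy(xy[:, [-1]])  # list index keeps it 2-D -> shape (5, 1)
print(x_data.shape, y_vec.shape, y_data.shape)

# Trace the feature dimension through the layers: 8 -> 6 -> 4 -> 1
layers = [torch.nn.Linear(8, 6), torch.nn.Linear(6, 4), torch.nn.Linear(4, 1)]
h = x_data
for layer in layers:
    h = torch.sigmoid(layer(h))
    print(tuple(h.shape))
```

The (N, 1) form of `y_data` matches the (N, 1) output of the last layer, which is why the original uses `[-1]` rather than `-1`.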
The final lines of the script above test the trained model on a single sample of eight feature values.
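To evaluate more than one sample, a common approach (not part of the original script) is to threshold the sigmoid output at 0.5 and measure accuracy over a whole set. A minimal sketch, where `predict_labels`, `accuracy`, the 0.5 cut-off and the toy `FirstFeature` model are all hypothetical choices for illustration:

```python
import torch

def predict_labels(model: torch.nn.Module, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: turn sigmoid probabilities into 0/1 labels."""
    model.eval()
    with torch.no_grad():          # inference only, no gradients
        probs = model(x)           # shape (N, 1), values in (0, 1)
    return (probs >= threshold).float()

def accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of predicted labels that equal the 0/1 targets."""
    return (pred == target).float().mean().item()

# Toy stand-in model: sigmoid of the first feature only
class FirstFeature(torch.nn.Module):
    def forward(self, x):
        return torch.sigmoid(x[:, :1])

x = torch.tensor([[2.0] + [0.0] * 7,
                  [-1.0] + [0.0] * 7,
                  [0.5] + [0.0] * 7])
y = torch.tensor([[1.0], [0.0], [0.0]])
pred = predict_labels(FirstFeature(), x)
print(pred.squeeze(1).tolist(), accuracy(pred, y))
```

With the trained `model` from the script, the same call would be `predict_labels(model, x_data)` followed by `accuracy(..., y_data)`.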
Next, change the hidden-layer activation function to ReLU:
```python
import numpy as np
import torch

# Load the dataset (np.loadtxt decompresses .gz files itself); delimiter=',' means columns are comma-separated
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] inside [] keeps an (N, 1) matrix, not an (N,) vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()   # activation for the hidden layers

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        # Keep Sigmoid on the output layer: ReLU returns 0 for negative inputs,
        # and the BCE loss would then have to evaluate ln(0)
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()
# size_average=True is deprecated; reduction='mean' is the current equivalent
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    # Forward pass over the whole dataset
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())
    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test model
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
with torch.no_grad():  # no gradients needed for inference
    y_test = model(x_test)
print('y_pred=', y_test.item())
```
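The comment on the output layer is the reason it keeps Sigmoid even after switching the hidden layers to ReLU. The check below writes the binary cross-entropy formula out by hand and feeds it a ReLU output of exactly 0 with a target of 1. Per the PyTorch documentation, `torch.nn.BCELoss` clamps its log outputs to be at least -100, so it returns 100 instead of infinity. The sketch also shows `torch.nn.BCEWithLogitsLoss`, which takes the raw pre-sigmoid output and combines sigmoid and BCE in a numerically stabler way. Switching to it is a suggestion here, not something the original code does.

```python
import torch

bce = torch.nn.BCELoss(reduction='mean')

def manual_bce(y_pred, y):
    # loss = -mean(y * ln(y_pred) + (1 - y) * ln(1 - y_pred))
    return -(y * torch.log(y_pred) + (1 - y) * torch.log(1 - y_pred)).mean()

y = torch.tensor([[1.0]])

# A ReLU output can be exactly 0, so the manual formula hits ln(0) = -inf
relu_out = torch.relu(torch.tensor([[-3.0]]))
print(manual_bce(relu_out, y))   # infinite loss
print(bce(relu_out, y))          # BCELoss clamps the log, giving a finite value

# A sigmoid output stays inside (0, 1) for moderate logits, so the loss is finite
logits = torch.tensor([[-3.0]])
sig_loss = bce(torch.sigmoid(logits), y)
# BCEWithLogitsLoss applies the sigmoid internally and takes raw logits
logit_loss = torch.nn.BCEWithLogitsLoss()(logits, y)
print(sig_loss.item(), logit_loss.item())  # the two values agree
```

If `BCEWithLogitsLoss` were used in the script, the final `self.sigmoid` in `forward` would be dropped during training and applied only when converting outputs to probabilities.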