PyTorch - 07. Processing multi-dimensional input features
2022-08-10 05:56:00 【Shengxin Research Ape】
import numpy as np
import torch

# Load the dataset; delimiter=',' because the file is comma-separated.
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the 8 features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] keeps it a matrix (N x 1), not a vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

# reduction='mean' replaces the deprecated size_average=True.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test the model on a single sample
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
y_test = model(x_test)
print('y_pred=', y_test.item())
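To see how the multi-dimensional input moves through the network, here is a minimal sketch that traces the tensor shapes layer by layer. It uses a random batch of 5 samples as a hypothetical stand-in for the diabetes data (the batch size and the random values are assumptions, not part of the original script); only the layer sizes 8, 6, 4 and 1 come from the model above.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the dataset: 5 samples with 8 features each.
x = torch.randn(5, 8)

# Same layer sizes as the model in the post.
linear1 = torch.nn.Linear(8, 6)
linear2 = torch.nn.Linear(6, 4)
linear3 = torch.nn.Linear(4, 1)
sigmoid = torch.nn.Sigmoid()

# Each Linear layer acts on the last dimension; the batch dimension is untouched.
h1 = sigmoid(linear1(x))   # (5, 8) -> (5, 6)
h2 = sigmoid(linear2(h1))  # (5, 6) -> (5, 4)
y = sigmoid(linear3(h2))   # (5, 4) -> (5, 1)

print(h1.shape, h2.shape, y.shape)
```

Each row of the input is one sample, so the whole batch is processed in a single call and the output has one probability per row.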
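Test on a single sample: the final lines of the training script pass one 8-feature row through the trained model and print the predicted probability.
Change the hidden-layer activation function to ReLU (the output layer keeps the sigmoid):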
import numpy as np
import torch

# Load the dataset; delimiter=',' because the file is comma-separated.
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last (the 8 features)
y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] keeps it a matrix (N x 1), not a vector


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        # Keep the sigmoid on the output layer: ReLU outputs 0 for negative
        # inputs, which would put ln(0) into the loss calculation.
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

# reduction='mean' replaces the deprecated size_average=True.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, ' loss=', loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --------------------------------------------------------
# Test the model on a single sample
x_test = torch.tensor([[-0.29, 0.48, 0.18, -0.29, 0, 0, -0.53, -0.03]])
y_test = model(x_test)
print('y_pred=', y_test.item())
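The reason for keeping the sigmoid on the last layer can be checked with a small sketch. The raw score of -2.0 and the target of 1.0 are made-up values for illustration. ReLU clips the score to exactly 0, so the loss would need ln(0); as far as I know, current PyTorch versions clamp the log term in BCELoss at -100, so the result shows up as a loss of about 100 rather than infinity. ReLU outputs can also exceed 1, and BCELoss rejects inputs outside [0, 1].

```python
import torch

criterion = torch.nn.BCELoss(reduction='mean')
target = torch.tensor([[1.0]])

# Hypothetical raw score from the last linear layer.
raw = torch.tensor([[-2.0]])

relu_out = torch.relu(raw)        # clipped to exactly 0.0
sigmoid_out = torch.sigmoid(raw)  # strictly between 0 and 1

# With ReLU the loss needs ln(0); with sigmoid the log is well defined.
loss_relu = criterion(relu_out, target)
loss_sigmoid = criterion(sigmoid_out, target)

print('relu output:', relu_out.item(), 'loss:', loss_relu.item())
print('sigmoid output:', sigmoid_out.item(), 'loss:', loss_sigmoid.item())
```

The sigmoid output stays strictly inside (0, 1), so the loss and its gradient remain usable, which is why only the hidden layers switch to ReLU.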