当前位置:网站首页>pytorch-07.处理多维特征的输入
pytorch-07.处理多维特征的输入
2022-08-10 05:32:00 【生信研究猿】
import numpy as np
import torch
xy = np.loadtxt('diabetes.csv.gz',delimiter=',',dtype=np.float32) # delimiter=',' : 分隔符为,
x_data = torch.from_numpy(xy[:,:-1]) #所有行 从第一列开始,最后一列不要
y_data = torch.from_numpy(xy[:,[-1]]) # 所有行 只要最后一列 []代表拿出来的是矩阵不是向量
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self,x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range (10000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print('epoch=', epoch, " loss=", loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
#--------------------------------------------------------
# Test Model
x_test = torch.Tensor([[-0.29,0.48,0.18,-0.29,0,0,-0.53,-0.03]])
y_test = model(x_test)
print('y_pred=',y_test.data.item())
测试一组数据:
把激活函数换成RELU:
import numpy as np
import torch
xy = np.loadtxt('diabetes.csv.gz',delimiter=',',dtype=np.float32) # delimiter=',' : 分隔符为,
x_data = torch.from_numpy(xy[:,:-1]) #所有行 从第一列开始,最后一列不要
y_data = torch.from_numpy(xy[:,[-1]]) # 所有行 只要最后一列 []代表拿出来的是矩阵不是向量
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
self.activate = torch.nn.ReLU()
def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.sigmoid(self.linear3(x)) #RELU,x小于0时的的y值都是0,算损失时有可能出现ln0
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range (10000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print('epoch=', epoch, " loss=", loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
#--------------------------------------------------------
# Test Model
x_test = torch.Tensor([[-0.29,0.48,0.18,-0.29,0,0,-0.53,-0.03]])
y_test = model(x_test)
print('y_pred=',y_test.data.item())
边栏推荐
猜你喜欢
国内数字藏品投资价值分析
Day1 微信小程序-小程序代码的构成
String常用方法
十年磨一剑!数字藏品行情软件,链读APP正式开放内测!
Batch add watermark to pictures batch add background zoom batch merge tool picUnionV4.0
测一测异性的你长什么样?
操作表 函数的使用
最新最全的数字藏品发售日历-07.27
Analysis of the investment value of domestic digital collections
Chain Reading Good Article: Jeff Garzik Launches Web3 Production Company
随机推荐
Batch add watermark to pictures batch add background zoom batch merge tool picUnionV4.0
21天挑战杯MySQL——Day06
智能合约和去中心化应用DAPP
事务、存储引擎
ORACLE系统表空间SYSTEM占满无法扩充表空间问题解决过程
cesium add point, move point
Day1 微信小程序-小程序代码的构成
树结构——2-3树图解
Database Notes Create Database, Table Backup
impdp import data
Chained Picks: Starbucks looks at digital collectibles and better engages customers
Chain Reading Recommendation: From Tiles to Generative NFTs
cesium listens to map zoom or zoom to control whether the content added on the map is displayed
view【】【】【】【】
MySql 约束
链读|最新最全的数字藏品发售日历-08.02
impdp 导入数据
我不喜欢我的代码
The complex "metaverse" will be interpreted for you, and the Link Reading APP will be launched soon!
Operation table Function usage