当前位置:网站首页>pytorch-07.处理多维特征的输入
pytorch-07.处理多维特征的输入
2022-08-10 05:32:00 【生信研究猿】
import numpy as np
import torch
xy = np.loadtxt('diabetes.csv.gz',delimiter=',',dtype=np.float32) # delimiter=',' : 分隔符为,
x_data = torch.from_numpy(xy[:,:-1]) #所有行 从第一列开始,最后一列不要
y_data = torch.from_numpy(xy[:,[-1]]) # 所有行 只要最后一列 []代表拿出来的是矩阵不是向量
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self,x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range (10000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print('epoch=', epoch, " loss=", loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
#--------------------------------------------------------
# Test Model
x_test = torch.Tensor([[-0.29,0.48,0.18,-0.29,0,0,-0.53,-0.03]])
y_test = model(x_test)
print('y_pred=',y_test.data.item())
测试一组数据:
把激活函数换成RELU:
import numpy as np
import torch
xy = np.loadtxt('diabetes.csv.gz',delimiter=',',dtype=np.float32) # delimiter=',' : 分隔符为,
x_data = torch.from_numpy(xy[:,:-1]) #所有行 从第一列开始,最后一列不要
y_data = torch.from_numpy(xy[:,[-1]]) # 所有行 只要最后一列 []代表拿出来的是矩阵不是向量
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
self.activate = torch.nn.ReLU()
def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.sigmoid(self.linear3(x)) #RELU,x小于0时的的y值都是0,算损失时有可能出现ln0
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range (10000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print('epoch=', epoch, " loss=", loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
#--------------------------------------------------------
# Test Model
x_test = torch.Tensor([[-0.29,0.48,0.18,-0.29,0,0,-0.53,-0.03]])
y_test = model(x_test)
print('y_pred=',y_test.data.item())
边栏推荐
猜你喜欢
One step ahead, don't miss it again, the chain reading APP will be launched soon!
Count down the six weapons of the domestic interface collaboration platform!
最新最全的数字藏品发售日历-07.27
测一测异性的你长什么样?
链读|最新最全的数字藏品发售日历-08.02
MySQL中MyISAM为什么比InnoDB查询快
力扣——省份数量
常用类 String概述
集合 Map
Chain Reading | The latest and most complete digital collection calendar-07.28
随机推荐
转载fstream,ifstream的详细用法
网络安全3
cesium 旋转图片
IDEA连接MySQL数据库并执行SQL查询操作
符号表
力扣——省份数量
去中心化和p2p网络以及中心化为核心的传统通信
Common class BigDecimal
链读推荐:从瓷砖到生成式 NFT
opencv
cesium listens to map zoom or zoom to control whether the content added on the map is displayed
WeChat applet wx.writeBLECharacteristicValue Chinese character to buffer problem
2021-06-22
Batch add watermark to pictures batch add background zoom batch merge tool picUnionV4.0
cesium rotate image
Chained Picks: Starbucks looks at digital collectibles and better engages customers
最新最全的数字藏品发售日历-07.26
Decentralized and p2p networks and traditional communications with centralization at the core
Reprint fstream, detailed usage of ifstream
每天一个小知识点