当前位置:网站首页>记录贴:pytorch学习Part5
记录贴:pytorch学习Part5
2022-08-08 17:33:00 【安联之夜】
记录贴:pytorch学习Part5
一、Embedding
import torch
import torch.nn as nn
# nn.Embedding
word_to_ix = {
"hello":0,"world":1}
lookup_tensor = torch.tensor([word_to_ix["hello"]],dtype=torch.long)
embeds = nn.Embedding(2,5)#两个单词五个维度,查表得到
hello_embed = embeds(lookup_tensor)
#该表一般由词向量模型提前生成
# load word embedding
rnn = RNN(len(TEXT.vocab), 100, 256)
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
二、RNN
#单层Rnn
rnn = nn.RNN(100,10)#词向量维度为100,短期记忆长度为20
rnn._parameters.keys()
rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers = 1)#单词100维,记忆20维
x = torch.randn(10,3,100)#单词,句子,向量
out,h = rnn(x,torch.zeros(1,3,20))#ho,一层,三个句子,记忆20维,可空
print(out.shape,h.shape)
#Rnncell
cell1 = nn.RNNCell(100,20)#100维,记忆20
h1 = torch.zeros(3,20)
x = torch.randn(10,3,100)
for xt in x:
h1 = cell1(xt,h1)
print(h1.shape)
#双层Rnn
rnn = nn.RNN(100,10,num_layers = 2)
rnn._parameters.keys()
rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers = 4)#单词100维,记忆20维
x = torch.randn(10,3,100)#单词,句子,向量
out,h = rnn(x)
print(out.shape,h.shape)
cell1 = nn.RNNCell(100,30)
cell2 = nn.RNNCell(30,20)
h1 = torch.zeros(3,30)
h2 = torch.zeros(3,20)
for xt in x:
h1 = cell1(xt,h1)
h2 = cell2(h1,h2)
print(h2.shape)
三、Lstm
lstm = nn.LSTM(input_size=100,hidden_size=20,num_layers=4)
print(lstm)
x = torch.randn(10,3,100)
out,(h,c) = lstm(x)
cell = nn.LSTMCell(input_size=100,hidden_size=20)
h = torch.zeros(3,20)
c = torch.zeros(3,20)
for xt in x:
h,c = cell(xt,[h,c])
cell1 = nn.LSTMCell(input_size=100,hidden_size=30)
cell2 = nn.LSTMCell(input_size=30,hidden_size=20)
h1 = torch.zeros(3,30)
c1 = torch.zeros(3,30)
h2 = torch.zeros(3,20)
c2 = torch.zeros(3,20)
for xt in x:
h1,c1 = cell1(xt,[h1,c1])
h2,c2 = cell2(h1,[h2,c2])
四、时间序列预测
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
#基本参数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01
iterations = 6000
#生成训练数据
start = np.random.randint(3,size = 1)[0]
time_steps = np.linspace(start,start + 10,num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps,1)
x = torch.tensor(data[:-1]).float().view(1,num_time_steps - 1,1)
y = torch.tensor(data[1:]).float().view(1,num_time_steps - 1,1)
#定义RNN
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.rnn = nn.RNN(
input_size = input_size,
hidden_size = hidden_size,
num_layers = 1,
batch_first = True,
)
self.linear = nn.Linear(hidden_size,output_size)
#前向传播过程
def forward(self,x,hidden_prev):
out,hidden_prev = self.rnn(x,hidden_prev)
out = out.view(-1,hidden_size)
out = self.linear(out)
out = out.unsqueeze(dim = 0)
return out,hidden_prev
model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr = lr)
hidden_prev = torch.zeros(1,1,hidden_size)
for iter in range(iterations):
output,hidden_prev = model(x,hidden_prev)
hidden_prev = hidden_prev.detach()
loss = criterion(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter % 100 == 0:
print(f"Iteration: {
iter} loss: {
loss.item()}")
predictions = []
input = x[:,0,:]
for _ in range(x.shape[1]):
input = input.view(1,1,1)
(pred,hidden_prev) = model(input,hidden_prev)
input = pred
predictions.append(pred.detach().numpy().ravel()[0])
x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())
plt.scatter(time_steps[1:], predictions)
plt.show()
边栏推荐
猜你喜欢
orbslam2实验记录-----稠密建图
DSPE-PEG-Biotin,385437-57-0,磷脂-聚乙二醇-生物素用于生物分子的检测和纯化
Camera calibration toobox for Matlab(一)—— 工具包的基本使用
[Paper Reading] RAL 2022: Receding Moving Object Segmentation in 3D LiDAR Data Using Sparse 4D Convolutions
2 prerequisites for the success of "digital transformation" of enterprises!
Prometheus+Grafana监控系统
Fluorescein-PEG-CLS,胆固醇-聚乙二醇-荧光素用于缩短包封周期
arxiv国内镜像——快速下载
LeetCode_二叉树_中等_515.在每个树行中找最大值
R file not found problem
随机推荐
Tensorflow教程(四)——MNIST项目入门
什么是服务网格?在微服务体系中又是如何使用的?
socket concept
The difference between rv and sv
永续合约交易所系统开发逻辑详情
看到这个应用上下线方式,不禁感叹:优雅,太优雅了!
LeetCode_回溯_中等_491.递增子序列
leetcode:306. 累加数
Cy5反式环辛烯,TCO-Cy5,Cy5 trans-cyclooctene标记生物分子
并发与并行
用皮肤“听”音乐,网友戴上这款装备听音乐会:仿佛住在钢琴里
R file not found problem
在指南针炒股软件中的指标靠谱吗?安全吗?
Prometheus+Grafana监控系统
从2022投影行业最新报告,读懂2022年家用智能投影仪该怎么选!
L2-024 部落 (25 分)(并查集)
以数治企,韧性成长,2022 年中国 CIO 数字峰会成功举行
史上最强IDEA工具使用教程,你想要的全都有!
arxiv国内镜像——快速下载
The difference between a uri (url urn)