当前位置:网站首页>记录贴:pytorch学习Part5
记录贴:pytorch学习Part5
2022-08-08 17:33:00 【安联之夜】
记录贴:pytorch学习Part5
一、Embedding
import torch
import torch.nn as nn
# nn.Embedding
word_to_ix = {
"hello":0,"world":1}
lookup_tensor = torch.tensor([word_to_ix["hello"]],dtype=torch.long)
embeds = nn.Embedding(2,5)#两个单词五个维度,查表得到
hello_embed = embeds(lookup_tensor)
#该表一般由词向量模型提前生成
# load word embedding
rnn = RNN(len(TEXT.vocab), 100, 256)
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
二、RNN
#单层Rnn
rnn = nn.RNN(100,10)#词向量维度为100,短期记忆长度为20
rnn._parameters.keys()
rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers = 1)#单词100维,记忆20维
x = torch.randn(10,3,100)#单词,句子,向量
out,h = rnn(x,torch.zeros(1,3,20))#ho,一层,三个句子,记忆20维,可空
print(out.shape,h.shape)
#Rnncell
cell1 = nn.RNNCell(100,20)#100维,记忆20
h1 = torch.zeros(3,20)
x = torch.randn(10,3,100)
for xt in x:
h1 = cell1(xt,h1)
print(h1.shape)
#双层Rnn
rnn = nn.RNN(100,10,num_layers = 2)
rnn._parameters.keys()
rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers = 4)#单词100维,记忆20维
x = torch.randn(10,3,100)#单词,句子,向量
out,h = rnn(x)
print(out.shape,h.shape)
cell1 = nn.RNNCell(100,30)
cell2 = nn.RNNCell(30,20)
h1 = torch.zeros(3,30)
h2 = torch.zeros(3,20)
for xt in x:
h1 = cell1(xt,h1)
h2 = cell2(h1,h2)
print(h2.shape)
三、Lstm
lstm = nn.LSTM(input_size=100,hidden_size=20,num_layers=4)
print(lstm)
x = torch.randn(10,3,100)
out,(h,c) = lstm(x)
cell = nn.LSTMCell(input_size=100,hidden_size=20)
h = torch.zeros(3,20)
c = torch.zeros(3,20)
for xt in x:
h,c = cell(xt,[h,c])
cell1 = nn.LSTMCell(input_size=100,hidden_size=30)
cell2 = nn.LSTMCell(input_size=30,hidden_size=20)
h1 = torch.zeros(3,30)
c1 = torch.zeros(3,30)
h2 = torch.zeros(3,20)
c2 = torch.zeros(3,20)
for xt in x:
h1,c1 = cell1(xt,[h1,c1])
h2,c2 = cell2(h1,[h2,c2])
四、时间序列预测
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
#基本参数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01
iterations = 6000
#生成训练数据
start = np.random.randint(3,size = 1)[0]
time_steps = np.linspace(start,start + 10,num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps,1)
x = torch.tensor(data[:-1]).float().view(1,num_time_steps - 1,1)
y = torch.tensor(data[1:]).float().view(1,num_time_steps - 1,1)
#定义RNN
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.rnn = nn.RNN(
input_size = input_size,
hidden_size = hidden_size,
num_layers = 1,
batch_first = True,
)
self.linear = nn.Linear(hidden_size,output_size)
#前向传播过程
def forward(self,x,hidden_prev):
out,hidden_prev = self.rnn(x,hidden_prev)
out = out.view(-1,hidden_size)
out = self.linear(out)
out = out.unsqueeze(dim = 0)
return out,hidden_prev
model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr = lr)
hidden_prev = torch.zeros(1,1,hidden_size)
for iter in range(iterations):
output,hidden_prev = model(x,hidden_prev)
hidden_prev = hidden_prev.detach()
loss = criterion(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter % 100 == 0:
print(f"Iteration: {
iter} loss: {
loss.item()}")
predictions = []
input = x[:,0,:]
for _ in range(x.shape[1]):
input = input.view(1,1,1)
(pred,hidden_prev) = model(input,hidden_prev)
input = pred
predictions.append(pred.detach().numpy().ravel()[0])
x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())
plt.scatter(time_steps[1:], predictions)
plt.show()
边栏推荐
猜你喜欢
随机推荐
poj1961 Period(KMP)
【NodeJs篇】fs文件系统模块
爬百度图片
Regular use in js
中金证券股票开户流程是什么,我需要准备身份证吗,安全吗
Obtain - 64 [chances] : the soldier, subtlety also - 5 - read sun tzu - melee meter
leetcode:306. 累加数
顺序表与链表结构及解析
leetcode:313. 超级丑数
一甲子,正青春,CCF创建六十周年庆典在苏州举行
永续合约交易所系统开发逻辑详情
L2-025 分而治之 (25 分)
C人脸识别
Open source summer | I have nothing to do during the epidemic, I made a button display box special effect to display my blog
Tensorflow教程(五)——MNIST项目提高
【AI玩家养成记 • 第3期】AI开发者必备!史上最适合新手的昇腾AI环境搭建教程!!
L2-011 玩转二叉树 (25 分) (二叉树)
How to set timed network disconnection to assist self-discipline in win10
L2-015 互评成绩 (25 分)
比较器是否可以当做运放使用?