当前位置:网站首页>【Pytorch】nn.Linear,nn.Conv
【Pytorch】nn.Linear,nn.Conv
2022-08-11 06:28:00 【二进制人工智能】
nn.Linear
nn.Conv1d
当nn.Conv1d
的kernel_size=1
时,效果与nn.Linear
相同,不过输入数据格式不同:
https://blog.csdn.net/l1076604169/article/details/107170146
import torch
def count_parameters(model):
"""Count the number of parameters in a model."""
return sum([p.numel() for p in model.parameters()])
conv = torch.nn.Conv1d(3, 32, kernel_size=1)
print(count_parameters(conv))
# 128
linear = torch.nn.Linear(3, 32)
print(count_parameters(linear))
# 128
print(conv.weight.shape)
# torch.Size([32, 3, 1])
print(linear.weight.shape)
# torch.Size([32, 3])
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(2))
linear.bias = torch.nn.Parameter(conv.bias)
tensor = torch.randn(128, 256, 3) # [batch, feature_num,feature_size]
permuted_tensor = tensor.permute(0, 2, 1).clone().contiguous() # [batch, feature_size,feature_num]
out_linear = linear(tensor)
print(out_linear.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_linear.shape)
# torch.Size([128, 256, 32])
out_conv = conv(permuted_tensor)
print(out_conv.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_conv.shape)
# torch.Size([128, 32, 256])
nn.Conv2d
nn.Conv3d
边栏推荐
猜你喜欢
随机推荐
JVM学习——3——数据一致性
js判断图片是否存在
Redis测试
Go语言实现Etcd服务发现(Etcd & Service Discovery & Go)
语音信号处理:预处理【预加重、分帧、加窗】
exness:黄金1800关口遇阻,静待美国CPI出炉
2022-08-10 第四小组 修身课 学习笔记(every day)
详述MIMIC 的ICU患者检测时间信息表(十六)
每日sql -用户两天留存率
Redis源码-String:Redis String命令、Redis String存储原理、Redis字符串三种编码类型、Redis String SDS源码解析、Redis String应用场景
sql--7天内(含当天)购买次数超过3次(含),且近7天的购买金额超过1000的用户
1688商品详情接口
Douyin get douyin share password url API return value description
抖音分享口令url API工具
radix-4 FFT 原理和C语言代码实现
Coordinate system in navigation and positioning
Daily sql-employee bonus filtering and answer rate ranking first
从苹果、SpaceX等高科技企业的产品发布会看企业产品战略和敏捷开发的关系
maxwell 概念
mysql视图与索引