当前位置:网站首页>【Pytorch】nn.Linear,nn.Conv
【Pytorch】nn.Linear,nn.Conv
2022-08-11 06:28:00 【二进制人工智能】
nn.Linear
nn.Conv1d
当nn.Conv1d
的kernel_size=1
时,效果与nn.Linear
相同,不过输入数据格式不同:
https://blog.csdn.net/l1076604169/article/details/107170146
import torch
def count_parameters(model):
"""Count the number of parameters in a model."""
return sum([p.numel() for p in model.parameters()])
conv = torch.nn.Conv1d(3, 32, kernel_size=1)
print(count_parameters(conv))
# 128
linear = torch.nn.Linear(3, 32)
print(count_parameters(linear))
# 128
print(conv.weight.shape)
# torch.Size([32, 3, 1])
print(linear.weight.shape)
# torch.Size([32, 3])
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(2))
linear.bias = torch.nn.Parameter(conv.bias)
tensor = torch.randn(128, 256, 3) # [batch, feature_num,feature_size]
permuted_tensor = tensor.permute(0, 2, 1).clone().contiguous() # [batch, feature_size,feature_num]
out_linear = linear(tensor)
print(out_linear.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_linear.shape)
# torch.Size([128, 256, 32])
out_conv = conv(permuted_tensor)
print(out_conv.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_conv.shape)
# torch.Size([128, 32, 256])
nn.Conv2d
nn.Conv3d
边栏推荐
- Trill keyword search goods - API
- jar服务导致cpu飙升问题-带解决方法
- 拼多多API接口(附上我的可用API)
- What are the things that should be planned from the beginning when developing a project with Unity?How to avoid a huge pit in the later stage?
- 京东商品详情API调用实例讲解
- Redis测试
- Douyin API interface
- 2022-08-09 Group 4 Self-cultivation class study notes (every day)
- 亚马逊获得AMAZON商品详情 API 返回值说明
- Taobao sku API interface (PHP example)
猜你喜欢
Daily sql: request for friend application pass rate
Douyin get douyin share password url API return value description
Edge provides label grouping functionality
《猪猪1984》NFT 作品集将上线 The Sandbox 市场平台
Get Pinduoduo product information operation details
每日sql-统计各个专业人数(包括专业人数为0的)
获取拼多多商品信息操作详情
下一代 无线局域网--强健性
导航定位中的坐标系
EasyPlayer针对H.265视频不自动播放设置下,loading状态无法消失的解决办法
随机推荐
基于FPGA的FIR滤波器的实现(4)— 串行结构FIR滤波器的FPGA代码实现
Discourse 的关闭主题(Close Topic )和重新开放主题
Discourse's Close Topic and Reopen Topic
ssh服务攻防与加固
【推荐系统】:协同过滤和基于内容过滤概述
【latex异常和错误】Missing $ inserted.<inserted text>You can‘t use \spacefactor in math mode.输出文本要注意特殊字符的转义
拼多多api接口应用示例
Implement general-purpose, high-performance sorting and quicksort optimizations
基于FPGA的FIR滤波器的实现(5)— 并行结构FIR滤波器的FPGA代码实现
Coordinate system in navigation and positioning
空间金字塔池化 -Spatial Pyramid Pooling(含源码)
从 dpdk-20.11 移植 intel E810 百 G 网卡 pmd 驱动到 dpdk-16.04 中
那些事情是用Unity开发项目应该一开始规划好的?如何避免后期酿成巨坑?
技能在赛题解析:交换机防环路设置
每日sql--统计员工近三个月的总薪水(不包括最新一个月)
Daily sql - judgment + aggregation
皮质-皮质网络的多尺度交流
LeetCode刷题系列 -- 46. 全排列
Spatial Pyramid Pooling -Spatial Pyramid Pooling (including source code)
EasyPlayer针对H.265视频不自动播放设置下,loading状态无法消失的解决办法