当前位置:网站首页>提取CNN模型中间层输出方法
提取CNN模型中间层输出方法
2022-04-21 15:30:00 【8倍】
前言
针对结构中定义了多个nn.sequential的网络模型,无法直接获取其内部某一中间层的输出,本文将给出两个方法进行解决。


方法
1 逐层进行forward
创建自定义函数,实现按照执行顺序逐层前向执行网络模型。
-----将模型输入以及模型作为参数传入函数,返回目标结果
def extract_res(inp, model):
for index, module in enumerate(model.modules()): # 按执行顺序遍历网络各层操作
if index in [0, 1, ...]: # 去除非操作层
continue
inp = module(inp) # 逐层前向执行,得到结果
if index == 3: # 判断是否为目标层 (示例为索引为3的操作)
return inp
tip: 利用.modules()在进行遍历操作时,其顺序为:
【总网络结构–>各部分–>各部分内部】







==》可见index = 0,1对应为非操作层,需要避免其执行forward。
故在使用此方法时,需要注意摒弃非操作层,跳过执行。此外,须推理出目标输出层对应索引号,才能实现精准获取。
2 使用hook函数
(1)定义保存hook内容的对象类
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs = []
(2) 为卷积层注册hook
hook_handles = []
save_output = SaveOutput()
for layer in model.modules(): # 按执行顺序遍历网络各层操作
if isinstance(layer, nn.Linear): # 按操作指令进行判别
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
代码中示例即为寻找所有为执行nn.Linear()的操作层
(3) 对输入x进行预测(过程中每计算一个输出将自动调用hook函数)
out = model(x)
(4)取出通过目标层的输出
data = save_output.outputs[2] # 2为目标操作层输出在最终结果列表中的索引
tip: 网络包含几个操作同名层,save_output.outputs的size就为多少,取出对应位置的输出即可
------tbc-------
有用可以点个大拇指哦 🤭
版权声明
本文为[8倍]所创,转载请带上原文链接,感谢
https://blog.csdn.net/W9XM96/article/details/124316119
边栏推荐
- AcWing 1788. Why do cows cross the road (simulation)
- 49页石油石化行业信息化规划与数字化转型
- 105 page digital twin city information model CIM platform construction technical scheme
- 易语言CEF3获取请求返回的源码
- Reading breaks ten thousand "volumes": National Reading insight 2022
- 105页数字孪生城市信息模型CIM平台建设技术方案
- MySQL8.0正确修改密码的姿势
- OpenLayers入门(一)
- 50页京东云·睿擎-打造企业数字化转型的敏捷引擎业务中台解决方案
- AcWing 1854. Promotion count (Analog)
猜你喜欢

Login refactoring notes

AcWing 1854. Promotion count (Analog)

On the import and export of browser bookmarks

Web.xml文件详解

Elemetn form control --- automatically locate the position of the field when it is submitted without passing the verification field

Oracle official announcement: Tencent JDK 18 ranks first in China!

别紧张,就是聊聊软考

Applet introduction and development tools

外贸公司一般用什么邮箱,电子邮件如何群发?

返璞归真,多方安全计算要回归到“安全”的本源考虑
随机推荐
[binary search - simple] 69 Square root of X
Can station B be called YouTube in China?
Web.xml文件详解
AcWing1800. 不做最后一个(枚举)
[binary search - simple] LCP 28 Procurement scheme
105 page digital twin city information model CIM platform construction technical scheme
Site intelligent solution
Elemetn form control --- automatically locate the position of the field when it is submitted without passing the verification field
Obsidian 自动上传图片到图床——安装PicGo插件并配置
OpenLayers入门(二)
Solution de transformation de mot de données de 111 pages Fine Chemical Co., Ltd.
AcWing 1812. 方形牧场(枚举)
机器学习方法创建可学习的化学语法,构建可合成的单体和聚合物
EmlParse:一款超轻量级的批量解析EML格式电子邮件的工具
终极套娃 2.0|云原生 PaaS 平台的可观测性实践分享
SQL服务器如何设置起始日期查询语句
Deltix Round, Summer 2021 E. Equilibrium
无常损失简单解释
LaTeX常用公式查询
AcWing 1812. Square pasture (enumeration)