当前位置:网站首页>实验记录:搭建网络过程
实验记录:搭建网络过程
2022-08-09 11:32:00 【匿名的魔术师】
一、遇到的问题
1.RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
当想要测试搭建的网络,看看各个过程的输出形状,随便创建个固定shape的输入如下
search_var = torch.FloatTensor(1, 3, 255, 255).cuda()
观察报错的问题,发现是输入 是 torch.cuda.FloatTensor类型的,而 weight是 torch.FloatTensor类型的,所以,需要把 model 放到cuda上才行。
解决:
model =model.cuda()
2. TypeError: 'PointTarget' object is not iterable
这个是由于在 设计 points时产生的问题,由于定义类中的
class PointTarget:
def __init__(self):
def __call__(self, target, size, neg=False):
__init__ 和 __call__ 传入参数冲突引起的,故需要做出改变。
而且在利用类时一定要首先初始化实例对象,
3.debug 时 变量若显示unable to get repr for <class ‘torch.Tensor‘
分类损失函数,用的focal loss, 所以送入之前数据需要映射到0~1之间,故需要进行一下sigmoid运算 。
4. RuntimeError: CUDA error: device-side assert triggered
5 .神经网络的输出出现nan值
这个可能由于搭建网络结构时,某层的网络的输出没有规一化或者 relu 输出,导致输出值不正常
6. config.py 与 config.yaml冲突
若config.yaml中设置的参数值同时也出现在了config.py文件中的话,一定要确保config.py文件中参数设置的值 与 config.yaml中设置的一致
7. 分布式训练卡住
可以设置
--nproc_per_node=1
但第二天早上 又 设置成 2又不卡了。
二、 实现历程
1. 搭建自己的网络
首先根据设计的方法搭建模型,编写代码。 搭建模型的过程中 分部分来完成,即先完成基础的每个部分,然后将它们连接在一起。比如举例来说 分成了 backbone 和 head。
最后通过创建一个model类来完成模型的前向传播过程,这个会继承pytorch的nn.Module类
可以随便创建哥输入变量,然后一步一步看看传输过程中特征shape的变化,创建tensor变量的语句如下
template_var = torch.rand((1, 3, 127, 127))
search_var = torch.rand((1, 3, 255, 255))
2. 设计以及实现损失函数的计算
搭建完网络结构之后,可以通过model得到想要的输出形式,然后接下来就是损失函数的设计与实现。首先,根据损失函数的需要,先把整个流程确定下来。在这个过程中,涉及到 标签的get,根据想要的输出怎么得到对应的标签呢?标签的shape 一般来说都是与 model 的output shape 相对应的。这里 标签一定要与 得到的预测输出一一对应,与之相联系的就是在进行标签和预测输出进行 permute 和 torch.cat 等等操作时shape变换的对应。
1) 创建points
2) 得到标签 cls 和 reg
3) 组成data 正式训练时构建 dataloader
4) 搭建模型 model
5) 数据送入model,得到预测输出
3. 嵌入模型
1) 注意data的重新设置。 data中的标签看看是否匹配了
2) 注意训练时 model 的 forward 过程,是否顺利
3) 网络中各模块学习率和优化器的设置,注意这方面的解耦
学习率的设置 是通过字典来设置的,参数 和 lr 键 以及再设置对应的值
trainable_params += [{'params': filter(lambda x: x.requires_grad,
model.backbone.parameters()),
'lr': cfg.BACKBONE.LAYERS_LR * cfg.TRAIN.BASE_LR}] # <c> 可能是骨干网络和其他的部分 初始学习率不一样
if cfg.ADJUST.ADJUST: # True
trainable_params += [{'params': model.fpn.parameters(),
'lr': cfg.TRAIN.BASE_LR}] # <c>
trainable_params += [{'params': model.ban.parameters(),
'lr': cfg.TRAIN.BASE_LR}]
4. 推理过程
推理过程与训练过程是两个截然不同的部分,一般推理过程和训练过程在同一个tracker类里,不过它们需要分开去定义自己的流程。
边栏推荐
- The use of signal function (signal) in C language
- ACM01背包问题
- 富媒体在客服IM消息通信中的秒发实践
- ECCV 2022 Oral | CCPL: 一种通用的关联性保留损失函数实现通用风格迁移
- x86 Exception Handling and Interrupt Mechanism (1) Overview of the source and handling of interrupts
- Qt 国际化翻译
- 【精华文】C语言结构体特殊情况分析:结构体指针 / 基本数据类型指针,指向其他结构体
- 学生成绩查找系统
- 使用.NET简单实现一个Redis的高性能克隆版(四、五)
- ClickHouse物化视图(八)
猜你喜欢
log4net使用指南(winform版,sqlserver记录)
二重指针-char **、int **的作用
Arduino学习总结 + 实习项目
ICML 2022 | Out-of-Distribution检测与深最近的邻居
[现代控制理论]4_PhasePortrait爱情故事动态系统分析
Number theory knowledge
x86 Exception Handling and Interrupt Mechanism (3) Interrupt Handling Process
This application has no explicit mapping for /error, so you are seeing this as a fallback
ECCV 2022 Oral | CCPL: 一种通用的关联性保留损失函数实现通用风格迁移
C# async 和 await 理解
随机推荐
CANopen DS402名词
PTA习题 分类统计字符个数(C)
ACM01背包问题
SQL Server查询优化
win10 outlook邮件设置
es6的async函数
PAT1003
[C language] creation and use of dynamic arrays
《数字经济全景白皮书》银行业智能营销应用专题分析 发布
字符串 | 反转字符串 | 双指针法 | leecode刷题笔记
ThreadLocal类
在北京参加UI设计培训到底怎么样?
log4net使用指南(winform版,sqlserver记录)
TIC2000系列处理器在线升级
防止数据冒用的方法
The use of signal function (signal) in C language
【C language】typedef的使用:结构体、基本数据类型、数组
ZOJ1298(单源最短路径)
【Data augmentation in NLP】——1
论文分享 | ACL2022 | 基于迁移学习的论元关系提取