当前位置:网站首页>【学习一下】HF-Net 训练
【学习一下】HF-Net 训练
2022-04-23 05:47:00 【mightbxg】
HF-Net 可以提取图像描述子(global_descriptors)和图像中的特征点(keypoints)及其描述子(local_descriptors),前者用于图像检索,后者配合 SuperGlue/NN 等特征匹配算法可用于相机位姿计算。因此 HF-Net 的应用场景就是 SLAM 中的地图定位与位姿恢复。
虽然官方仓库已有完善的训练脚本和详细的使用说明,但有些细节还是需要注意的,在此记录一下。预先提醒:HF-Net 训练会产生 1.4TB 的中间数据,请确保数据盘有 1.5TB~2TB 的可用空间!
素材准备
首先从 github pull 源码,我一般放在 home 目录下,比如 /home/abc/Sources/hfnet。hfnet 设置时需要提供两个路径:DATA_PATH 和 EXPER_PATH,前者存放训练图像和预训练模型权重,后者存放 hfnet 的训练输出。对训练而言,需要下载的东西有:
- GoogleLandmarks 数据集: 先下载图片索引,然后使用
setup/scripts/download_google_landmarks.py脚本下载。索引文件中共有 1098k 张图的链接,但有一部分是无效的。随机下载 185k 张图(约 55GB)即可,多出的图在训练时会被抛弃掉(见配置文件hfnet_train_distill.yaml)。 - BerkeleyDeepDrive 数据集: 直接去官网,简单注册后在下载页面
BDD100K标签页中下载100K Images和Labels,总共约 5.4GB。 - NetVLAD 模型参数: 下载地址
- SuperPoint 模型参数: 下载地址
- Mobilenet V2 模型参数: 在下载页面找到
float_v2_0.75_224下载。
BDD100k 数据集需要挑选出 night 和 dawn 类别的图像,我写了一个脚本来执行该任务。解压 images_100k 和 labels ,下载 bdd_extract_images_for_hfnet.py 后,按如下方式组织文件夹:
bdd100k
├── bdd_extract_images_for_hfnet.py
├── images
│ └── 100k
└── labels
├── bdd100k_labels_images_train.json
└── bdd100k_labels_images_val.json
然后执行 bdd_extract_images_for_hfnet.py,即可得到 dawn_images_vga 和 night_images_vga 两个文件夹。
上面准备的这些东西都放在 DATA_PATH 目录下,最终 DATA_PATH 中的内容为:
HfNetDataset
├── bdd
│ ├── dawn_images_vga # 文件夹
│ └── night_images_vga # 文件夹
├── google_landmarks
│ └── images # 文件夹
└── weights
├── mobilenet_v2_0.75_224 # 文件夹
├── superpoint_v1.pth
└── vd16_pitts30k_conv5_3_vlad_preL2_intra_white # 文件夹
环境准备
HF-Net 使用的是 tensorflow 1.12,实测 tf 1.15.5 也是可以用的。基于镜像 tensorflow/tensorflow:1.15.5-gpu 创建 docker 容器:
docker run --gpus all --name "hfnet" -v /home/abc:/home/abc -it tensorflow/tensorflow:1.15.5-gpu /bin/bash
这里只挂载了 home 目录,因为我使用的代码和数据都在 home 目录下,如果你的 DATA_PATH 在另一个盘,也需要一并挂载(加个 -v /host_dir:/docker_dir)。本文开头提到了 DATA_PATH 需要至少 1.5TB 的可用空间!
进入容器后需要安装一个包:libgl1-mesa-glx,为了方便操作,tmux 和 vim 也可用装一下。
然后是配置 hfnet 仓库,在此之前改一下 setup/requirements.txt 文件,把 tensorflow-gpu==1.12 后面的 ==1.12 删掉,否则会对容器中已有的 tf1.15.5 降级。在 hfnet 根目录下执行 make install 进行安装配置。这里 DATA_PATH 和 EXPER_PATH 必须填写绝对路径。
模型训练
虽然 docker 容器启用了所有 GPU,但服务器上的显卡往往并不是全都能用的,可以在终端设置 CUDA_VISIBLE_DEVICES 环境变量来控制 GPU 的使用,比如:
export CUDA_VISIBLE_DEVICES=1 # 只启用1号GPU
export CUDA_VISIBLE_DEVICES=1,2 #只启用1、2号GPU
在真正的训练之前,需要先导出 NetVLAD 和 SuperPoint 模型的预测,作为数据集的 label(HF-Net 使用的是模型蒸馏训练方法,即 NetVLAD 和 SuperPoint 两个大模型监督训练 hfnet 这个小模型)。在 hfnet 根目录分别执行:
python3 hfnet/export_predictions.py \
hfnet/configs/netvlad_export_distill.yaml \
global_descriptors \
--keys global_descriptor \
--as_dataset
python3 hfnet/export_predictions.py \
hfnet/configs/superpoint_export_distill.yaml \
superpoint_predictions \
--keys local_descriptor_map,dense_scores \
--as_dataset
这一步会在 DATA_PATH 下得到 global_descriptors(4.6GB) 和 superpoint_predictions(1.4TB) 两个文件夹,耗时可能有数小时。
训练需要的数据集结构是这样的(global_descriptors 和 superpoint_predictions 与图像文件夹在同一层):
├── bdd
│ ├── dawn_images_vga
│ ├── global_descriptors
│ ├── night_images_vga
│ └── superpoint_predictions
└── google_landmarks
├── global_descriptors
├── images
└── superpoint_predictions
但前面得到的文件夹结构却是这样的:
├── bdd
│ ├── dawn_images_vga
│ └── night_images_vga
├── global_descriptors
├── google_landmarks
│ └── images
└── superpoint_predictions
一种方法是将 global_descriptors 和 superpoint_predictions 文件夹中属于两个数据集的部分分开放到各自文件夹中,更简单的方法是直接在 bdd 和 google_landmarks 文件夹中建立 global_descriptors 和 superpoint_predictions 文件夹的软连接,因为训练时是先找到图像,然后根据图像名字找对应的 label,某一数据集下有多余 label 并不会影响训练。
然后就是训练了(数小时到十几小时):
python3 hfnet/train.py hfnet/configs/hfnet_train_distill.yaml hfnet
训练过程 log 和训练后的模型保存在 EXPER_PATH/hfnet/,可以用 tensorboard 监测训练过程:
tensorboard --logdir=$EXPER_PATH/hfnet --host=0.0.0.0
# 本地浏览器访问 http://server_ip:port/ 查看
最后可以导出为 pb 模型(EXPER_PATH/saved_models/hfnet/):
python3 hfnet/export_model.py hfnet/configs/hfnet_train_distill.yaml hfnet
版权声明
本文为[mightbxg]所创,转载请带上原文链接,感谢
https://blog.csdn.net/mightbxg/article/details/119888881
边栏推荐
猜你喜欢
随机推荐
Rust 中的 RefCell
Gesture recognition research
Robocode教程4——Robocode的游戏物理
Robocode教程7——雷达锁定
ArcGIS表转EXCEL超出上限转换失败
D. Optimal partition segment tree optimization DP
Busybox initrd and initialization process
Rust 中的 Rc智能指针
C array
selenium+webdriver+chrome实现百度以图搜图
H. Are You Safe? Convex hull naked problem
P1018 maximum product solution
非参数化相机畸变模型简介
生成验证码
代理服务器
【UDS统一诊断服务】三、应用层协议(1)
serde - rust的序列化方案
Rainbow (DP)
Common shortcut keys of IDE
Rust:如何实现一个线程池?









![[untitled] database - limit the number of returned rows](/img/20/9a143e6972f1ce2eed5a3d11c3a46d.png)