当前位置:网站首页>对比学习系列(三)-----SimCLR
对比学习系列(三)-----SimCLR
2022-08-11 08:11:00 【陶将】
SimCLR
SimCLR通过隐藏空间的对比损失最大化相同数据在不同增广下的一致性来学习表达。SimCLR框架有四个主要的组件,分别是:数据增广,encode网络,projection head网络和对比学习函数。
对于数据 x x x,从同一个数据增广族中抽取两个独立的数据增广算子( t ∼ T t \sim T t∼T和 t ′ ∼ T {t}' \sim T t′∼T),以获得两个相关的视图 x ^ i \hat{x}_{i} x^i和 x ^ j \hat{x}_{j} x^j, x ^ i \hat{x}_{i} x^i和 x ^ j \hat{x}_{j} x^j是一对正样本,然后一个神经网络编码器 f ( ⋅ ) f\left( \cdot \right) f(⋅)从增广的数据中提取特征 h i = f ( x ^ i ) , h j = f ( x ^ j ) , h_{i}=f\left( \hat{x}_{i} \right), h_{j}=f\left( \hat{x}_{j} \right), hi=f(x^i),hj=f(x^j),。再然后一个小的神经网络project head g ( ⋅ ) g\left( \cdot \right) g(⋅)将特征映射到对比损失的空间。project head采用带有一个隐含层的MLP获取 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_{i} = g\left( h_{i} \right) = W^{\left( 2 \right)} \sigma \left( W^{\left( 1 \right)} h_{i}\right) zi=g(hi)=W(2)σ(W(1)hi)。
对于包含一对正样本 x ^ i \hat{x}_{i} x^i和 x ^ j \hat{x}_{j} x^j的集合 { x ^ k } \{ \hat{x}_{k} \} { x^k},对比预测任务目的是对于给定的 x ^ i \hat{x}_{i} x^i在 { x ^ } k ≠ i \{ \hat{x} \}_{k \neq i} { x^}k=i中识别出 x ^ j \hat{x}_{j} x^j。随机挑选 N N N个样本组成一个minibatch,这个minibatch中则有 2 N 2N 2N个数据样本,将其他 2 ( N − 1 ) 2\left( N - 1\right) 2(N−1)个扩增的样本作为这个minibatch中的负样本,设 s i m ( u , v ) = u T v / ∥ u ∥ ∥ v ∥ sim\left( u, v\right) = u^{T}v / \| u\| \| v\| sim(u,v)=uTv/∥u∥∥v∥表示 l 2 l_{2} l2正则化后你的 u u u和 v v v的点积,那么对一对正样本 ( i , j ) \left( i, j \right) (i,j),损失函数如下定义:
l i , j = − l o g e x p ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] e x p ( s i m ( z i , z k ) / τ ) l_{i,j} = - log \frac{exp\left( sim \left( z_{i}, z_{j}\right) / \tau \right)}{\sum_{k=1}^{2N} \mathbb{1}_{[ k \neq i]} exp\left( sim \left( z_{i}, z_{k}\right) / \tau \right)} li,j=−log∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
最后的损失函数计算一个minibatch中的所有的正样本对,包括 ( i , j ) \left( i, j \right) (i,j)和 ( j , i ) \left( j,i \right) (j,i)。下面是SimCLR的伪代码。从伪代码中可以看出,编码器 f ( ⋅ ) f\left( \cdot \right) f(⋅)和project head g ( ⋅ ) g\left( \cdot \right) g(⋅) 在训练时都会被更新参数,但是只有编码器 f ( ⋅ ) f\left( \cdot \right) f(⋅)用于下游任务。
simCLR不采用memory bank的形式进行训练,而是加大batchsize,bacth size为8192,对于每一个正样本,将会有16382张负样本实例。增大batch size其实相当于每个minibatch时动态生成一个memory bank。论文中发现使用标准的SGD/Momentum,大batch size训练时是不稳定的,论文中采用LARS优化器。
参考
边栏推荐
猜你喜欢
1.1-Regression
Hibernate 的 Session 缓存相关操作
1.1-回归
Serverless + domain name can also build a personal blog? Really, and soon
2.1 - Gradient Descent
Redis source code: how to view the Redis source code, the order of viewing the Redis source code, the sequence of the source code from the external data structure of Redis to the internal data structu
About # SQL problem: how to set the following data by commas into multiple lines, in the form of column display
【TA-霜狼_may-《百人计划》】图形3.7.2 command buffer简
Creo9.0 特征的成组
3.1-Classification-probabilistic generative model
随机推荐
1076 Wifi Password (15 points)
go-grpc TSL authentication solution transport: authentication handshake failed: x509 certificate relies on ... ...
【TA-霜狼_may-《百人计划》】图形3.7.2 command buffer简
Do you know the basic process and use case design method of interface testing?
支持各种文件快速重命名最简单的小技巧
笔试题大疆08.07
Four states of Activity
leetcode: 69. Square root of x
分门别类输入输出,Go lang1.18入门精炼教程,由白丁入鸿儒,go lang基本数据类型和输入输出EP03
流式结构化数据计算语言的进化与新选择
Keep track of your monthly income and expenses through bookkeeping
剑指offer专项突击版第26天
Project 1 - PM2.5 Forecast
Break pad source code compilation--refer to the summary of the big blogger
Test cases are hard?Just have a hand
Kotlin算法入门计算水仙花数
oracle19c does not support real-time synchronization parameters, do you guys have any good solutions?
【LeetCode】Summary of linked list problems
【BM87 合并两个有序的数组】
klayout--导出版图为gds文件