当前位置:网站首页>Detailed explanation of VIT source code
Detailed explanation of VIT source code
2022-08-11 03:26:00 【The romance of cherry blossoms】
1. Project configuration
Command-line arguments:
Run name and dataset:
--name cifar10-100_500
--dataset cifar10
Model variant:
--model_type ViT-B_16
Pretrained weights:
--pretrained_dir checkpoint/ViT-B_16.npz
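2. Patch embedding and position embedding
For the image encoding, take ViT-B/16 as an example. The image batch is first transformed by a convolution with a 16×16 kernel and stride 16; for a batch of 16 images of size 224×224 the result has shape [16, 768, 14, 14]. It is then flattened and transposed to [16, 196, 768], and a learnable class token of shape [16, 1, 768] (the "0th patch") is concatenated in front of it, giving [16, 197, 768].
For the position encoding, a learnable tensor of shape [1, 197, 768] is created.
Finally, the embedding is obtained by adding the position encoding to the image encoding (broadcast over the batch) and applying dropout.
The code is as follows: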
class Embeddings(nn.Module):
    """Construct the token embeddings from patch and position embeddings.

    The image is cut into non-overlapping patches by a strided convolution.
    A learnable class token is prepended, then a learnable position embedding
    is added and dropout is applied.

    Args:
        config: model configuration. Reads ``config.patches`` (a mapping with
            either a ``"grid"`` entry, which selects the hybrid ResNet model, or
            a ``"size"`` entry for plain patching), ``config.hidden_size``,
            ``config.transformer["dropout_rate"]`` and, for the hybrid model,
            ``config.resnet.num_layers`` / ``config.resnet.width_factor``.
        img_size: input image size, an int or an ``(H, W)`` pair.
        in_channels: number of channels fed to the patch convolution. For the
            hybrid model this is overridden by the ResNet output width.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        # Patch size and number of patches.
        if config.patches.get("grid") is not None:
            # Hybrid model: the code assumes the ResNet backbone downsamples
            # the image by 16, and patches are taken from its feature map.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0],
                          img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16

        # Patch embedding: a convolution whose kernel equals its stride, e.g.
        # [16, 3, 224, 224] -> [16, 768, 14, 14] for ViT-B/16.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # Learnable position embeddings, one per patch plus one for the class
        # token: [1, n_patches + 1, hidden_size], e.g. [1, 197, 768].
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size))
        # Learnable class token ("patch 0"); its final state is used for
        # classification. Shape [1, 1, hidden_size].
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch, presumably ``[B, C, H, W]`` with ``(H, W)`` equal to
                the ``img_size`` given at construction. The position embedding
                has a fixed length, so other sizes will not broadcast — confirm
                with callers.

        Returns:
            Tensor of shape ``[B, n_patches + 1, hidden_size]``.
        """
        batch_size = x.shape[0]
        # Broadcast the class token over the batch: [B, 1, hidden_size].
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        # [B, hidden_size, H/P, W/P], e.g. [16, 768, 14, 14]
        x = self.patch_embeddings(x)
        # [B, hidden_size, H/P, W/P] -> [B, hidden_size, n_patches] -> [B, n_patches, hidden_size]
        x = x.flatten(2).transpose(-1, -2)
        # Prepend the class token: [B, n_patches + 1, hidden_size]
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


# 3. Encoder
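Multi-head attention module:
First, the query, key and value vectors q, k, v are computed. Because multi-head attention is used (12 heads), q, k and v are reshaped from [16, 197, 768] to [16, 12, 197, 64]. The similarity between q and k is then computed as q·kᵀ. Since it relates every token to every other token, its shape is [16, 12, 197, 197]. The scores are divided by √64 (the square root of the per-head dimension) to normalize their scale and passed through softmax. Multiplying the resulting weights by v gives the attended features of shape [16, 12, 197, 64]. These are finally reshaped back to [16, 197, 768] and passed through an output linear layer.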
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: reads ``config.hidden_size``, ``config.transformer["num_heads"]``
            and ``config.transformer["attention_dropout_rate"]``.
        vis: if true, ``forward`` also returns the attention probabilities
            (taken before dropout); otherwise ``None`` is returned in their place.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        # Per-head dimension, e.g. 768 // 12 = 64 for ViT-B.
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        # Width of all heads concatenated; equals hidden_size whenever
        # hidden_size is divisible by num_heads.
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        # The output projection consumes the concatenated heads, so its input
        # width is all_head_size. In the divisible case this is identical to
        # Linear(hidden_size, hidden_size), so existing checkpoints still load.
        self.out = Linear(self.all_head_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # NOTE(review): the post-projection dropout also uses
        # attention_dropout_rate; it may be intended to use dropout_rate — confirm.
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: [B, N, all_head_size] -> [B, heads, N, head_size]."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # e.g. [16, 197, 768] -> [16, 197, 12, 64]
        x = x.view(*new_x_shape)
        # [16, 197, 12, 64] -> [16, 12, 197, 64]
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: tensor of shape ``[B, N, hidden_size]``.

        Returns:
            ``(attention_output, weights)``. ``attention_output`` has shape
            ``[B, N, hidden_size]``. ``weights`` has shape ``[B, heads, N, N]``
            when ``vis`` is true, otherwise it is ``None``.
        """
        # Project and split into heads: [B, N, hidden_size] -> [B, heads, N, head_size]
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Token-to-token similarity: [B, heads, N, N], scaled by sqrt(head_size)
        # to keep the softmax inputs in a well-conditioned range.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # Weighted sum of values: [B, heads, N, head_size]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [B, heads, N, head_size] -> [B, N, heads, head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge heads: [B, N, heads, head_size] -> [B, N, all_head_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output projection: [B, N, hidden_size]
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


# Transformer encoder
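For an input x, the block first applies layer normalization and then multi-head attention, and adds the result back to x through a residual connection. The sum is layer-normalized again, passed through an MLP (two fully connected layers), and added back through a second residual connection. L such blocks are stacked to produce the final output.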
class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x = x + Attn(LN(x))`` followed by ``x = x + MLP(LN(x))``.

    Args:
        config: reads ``config.hidden_size``; also passed to ``Mlp`` and
            ``Attention``.
        vis: forwarded to ``Attention``; when true, ``forward`` returns the
            attention weights instead of ``None``.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        # Token feature width, e.g. 768.
        self.hidden_size = config.hidden_size
        # Separate layer norms in front of the attention and MLP sub-layers.
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Run one encoder block.

        Args:
            x: token tensor, e.g. ``[16, 197, 768]``.

        Returns:
            ``(x, weights)``: the updated tokens and whatever ``Attention``
            returned as its weights.
        """
        # Attention sub-layer with residual connection.
        residual = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + residual

        # MLP sub-layer with residual connection.
        residual = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + residual
        return x, weights


# Overall architecture
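For an input x, the shape after patch embedding and position embedding is [16, 197, 768]. This is fed into the encoder and passed through L encoder blocks. The output at position 0 (the class token, which represents the classification feature) is then fed into the classification layer to obtain the prediction.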
class VisionTransformer(nn.Module):
    """ViT classifier: Transformer encoder plus a linear head on the class token.

    Args:
        config: reads ``config.classifier`` and ``config.hidden_size``; also
            passed to ``Transformer``.
        img_size: input image size, passed to ``Transformer``.
        num_classes: number of output classes.
        zero_head: stored flag; not read anywhere in this class.
        vis: forwarded to ``Transformer`` to request attention weights.
    """

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        """Classify a batch of images.

        Args:
            x: image batch passed to ``Transformer``.
            labels: optional class-index targets.

        Returns:
            The cross-entropy loss (``CrossEntropyLoss`` with default
            reduction) if ``labels`` is given; otherwise
            ``(logits, attn_weights)`` with logits of shape ``[B, num_classes]``.
        """
        # Encoder output, e.g. [16, 197, 768].
        x, attn_weights = self.transformer(x)
        # Classify from the class token at position 0, e.g. [16, 768] -> [16, 10].
        logits = self.head(x[:, 0])

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        return logits, attn_weights
边栏推荐
- A brief analysis of whether programmatic futures trading or manual order is better?
- 最倒霉与最幸运
- The impact of programmatic trading and subjective trading on the profit curve!
- Redis老了吗?Redis与Dragonfly性能比较
- MongoDB 基础了解(二)
- Qnet Weak Network Test Tool Operation Guide
- EasyCVR接入海康大华设备选择其它集群服务器时,通道ServerID错误该如何解决?
- 互换性测量技术-几何误差
- How does MSP430 download programs to the board?(IAR MSPFET CCS)
- Docker 链接sqlserver时出现en-us is an invalid culture错误解决方案
猜你喜欢

EasyCVR接入GB28181设备时,设备接入正常但视频无法播放是什么原因?

Traversal of DOM tree-----modify styles, select elements, create and delete nodes

STC8H开发(十五): GPIO驱动Ci24R1无线模块

MongoDB 基础了解(二)

Official release丨VS Code 1.70

多商户商城系统功能拆解26讲-平台端分销设置

【愚公系列】2022年08月 Go教学课程 035-接口和继承和转换与空接口

E-commerce project - mall time-limited seckill function system

正式发布丨VS Code 1.70

The most unlucky and the luckiest
随机推荐
阿里低代码框架 lowcode-engine 之自定义物料篇
言简意赅,说说 @Transactional 在项目中的使用
this question in js
Google search skills - programmer is recommended
What has programmatic trading changed?
font
互换性与测量技术-公差原则与选用方法
Goodbye Guangzhou paper invoices!The issuance of electronic invoices for accommodation fees will completely replace the invoices of hotels, restaurants and gas stations
Ninjutsu_v3_08_2020 - safety penetrating system installation
[idea error] Invalid target distribution: 17 solution reference
IDE编译报错:Dangling metacharacter
The 125th day of starting a business - a note
QueryDet:级联稀疏query加速高分辨率下的小目标检测
21天学习挑战赛第一周总结
广州纸质发票再见!开住宿费电子发票即将全面取代酒店餐饮加油站发票
flink The object probably contains or references non serializable fields.
浅析一下期货程序化交易好还是手工单好?
学编程的第十三天
Add support for Textbundle
Roewe imax8ev cube battery security, what blackening and swelling are hidden behind it?