Defining a custom neural network structure with Transformers
2022-08-08 06:24:00 【hithithithithit】
### Code
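```python
import torch
from torch import nn
from torch.optim import lr_scheduler
from transformers import BertTokenizer, BertModel, BertPreTrainedModel

# Loading a plain pretrained BERT on its own would look like this:
# bert = BertModel.from_pretrained('bert-base-uncased')


class BertModelCustom(BertPreTrainedModel):
    """
    A custom neural network built on top of BERT with transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Build BERT from the config only; its pretrained weights are filled in
        # by the from_pretrained() call used later to instantiate this class.
        # The attribute must be named `bert` to match BertPreTrainedModel.base_model_prefix.
        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, 128)  # hidden_size is 768 for bert-base
        self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.post_init()  # initialise the new layers (older versions: self.init_weights())

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Forward pass: BERT returns a ModelOutput, not a tensor, so take the pooled [CLS] vector
        pooled = self.bert(input_ids, attention_mask=attention_mask).pooler_output
        logits = self.hidden_dropout(self.linear(pooled))
        output = {"logits": logits}
        if labels is not None:
            # Assumption: treat the 128 outputs as class scores; the original defined no loss
            output["loss"] = nn.functional.cross_entropy(logits, labels)
        return output


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bmc = BertModelCustom.from_pretrained('bert-base-uncased')

for n, p in bmc.linear.named_parameters():
    print("Linear layer parameters:")
    print(n)
    print(p)
    break
for n, p in bmc.bert.named_parameters():
    print("BERT parameters:")
    print(n)
    print(p)
    break
print("hidden_dropout:", bmc.hidden_dropout)

# Call the model and run one training step
optimizer = torch.optim.AdamW(bmc.parameters(), lr=2e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

inputs = tokenizer("hello world", return_tensors="pt")
labels = torch.tensor([0])  # hypothetical target class
output = bmc(inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=labels)
loss = output["loss"]
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```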
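The comment in `__init__` says that `self.bert` is built from the config alone and only receives pretrained weights through `from_pretrained`. The sketch below checks that offline. It assumes a recent transformers 4.x (for `post_init()`), uses a deliberately tiny, hypothetical `BertConfig`, and uses a locally saved `BertModel` checkpoint as a stand-in for `bert-base-uncased`; the class name `TinyCustomBert` is hypothetical too.

```python
import tempfile

import torch
from torch import nn
from transformers import BertConfig, BertModel, BertPreTrainedModel


class TinyCustomBert(BertPreTrainedModel):
    """Hypothetical: same pattern as above, a BERT encoder as `self.bert` plus a new head."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)                   # name matches base_model_prefix ("bert")
        self.linear = nn.Linear(config.hidden_size, 4)  # new head, absent from the checkpoint
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        pooled = self.bert(input_ids, attention_mask=attention_mask).pooler_output
        return self.linear(pooled)


# Hypothetical tiny config so the demo needs no download
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)

torch.manual_seed(0)
plain_bert = BertModel(config)

with tempfile.TemporaryDirectory() as ckpt_dir:
    plain_bert.save_pretrained(ckpt_dir)            # stand-in for 'bert-base-uncased'
    custom = TinyCustomBert.from_pretrained(ckpt_dir)

# Compare one encoder weight: it should have been copied into custom.bert
same = torch.equal(plain_bert.embeddings.word_embeddings.weight,
                   custom.bert.embeddings.word_embeddings.weight)
print("encoder weights loaded:", same)
```

The attribute name `bert` is what makes this work: it matches `BertPreTrainedModel.base_model_prefix`, which `from_pretrained` uses to map checkpoint keys onto the submodule. The `linear` head has no counterpart in the checkpoint, so transformers initialises it randomly and logs a warning saying so.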