当前位置:网站首页>抱抱脸(hugging face)教程-中文翻译-翻译任务(translating)
抱抱脸(hugging face)教程-中文翻译-翻译任务(translating)
2022-08-09 14:56:00 【wwlsm_zql】
Translation
翻译将文本序列从一种语言转换为另一种语言。这是可以作为序列到序列问题来规划的几个任务之一,序列到序列问题是一个扩展到视觉和音频任务的强大框架。
本指南将向您展示如何在 OPUS Books 数据集的英语-法语子集上微调 T5,以便将英语文本翻译成法语。
有关其关联模型、数据集和指标的更多信息,请参见翻译任务页。
加载 OPUS 图书数据集
从 Datasets 库加载 OPUS Books 数据集:
from datasets import load_dataset
books = load_dataset("opus_books", "en-fr")
将这个数据集分割成一列火车和测试集:
books = books["train"].train_test_split(test_size=0.2)
然后看一个例子:
books["train"][0]
{
'id': '90560',
'translation': {
'en': 'But this lofty plateau measured only a few fathoms, and soon we reentered Our Element.',
'fr': 'Mais ce plateau élevé ne mesurait que quelques toises, et bientôt nous fûmes rentrés dans notre élément.'}}
翻译领域是包含文本的英语和法语翻译的字典。
预处理
加载 T5标记器来处理语言对:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")
预处理功能需要:
- 以提示符作为输入的前缀,这样 T5就知道这是一个翻译任务。一些能够执行多个 NLP 任务的模型需要对特定任务进行提示。
- 分别对输入(英语)和目标(法语)进行标记。你不能用英语词汇预先训练好的标记器来标记法语文本。上下文管理器将帮助首先将标记器设置为法语,然后再对其进行标记。
- 截断序列不要超过 max_length 参数设置的最大长度。
source_lang = "en"
target_lang = "fr"
prefix = "translate English to French: "
def preprocess_function(examples):
inputs = [prefix + example[source_lang] for example in examples["translation"]]
targets = [example[target_lang] for example in examples["translation"]]
model_inputs = tokenizer(inputs, max_length=128, truncation=True)
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
使用 Datasets 映射函数对整个数据集应用预处理函数。您可以通过设置 batching=True 来加速 map 函数,以便同时处理数据集中的多个元素:
>>> tokenized_books = books.map(preprocess_function, batched=True)
使用 DataCollatorForSeq2Seq 创建一批示例。它还将动态地将文本和标签填充到其批处理中最长元素的长度,因此它们是统一长度。虽然可以通过设置 padding=True 在 tokenizer 函数中填充文本,但是动态填充效率更高。
Pytorch
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
TensorFlow
>>> from transformers import DataCollatorForSeq2Seq
>>> data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="tf")
Train
Pytorch
使用 AutoModelForSeq2SeqLM 加载 T5:
>>> from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
目前,只剩下三个步骤:
在 Seq2SeqTrainingArguments 中定义训练超参数。
将训练参数与模型、数据集、标记器和数据校对器一起传递给 Seq2SeqTrainer。
调用 train() 对模型进行微调。
>>> training_args = Seq2SeqTrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=1,
fp16=True,
)
>>> trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_books["train"],
eval_dataset=tokenized_books["test"],
tokenizer=tokenizer,
data_collator=data_collator,
)
trainer.train()
TensorFlow
要在 TensorFlow 中微调模型,首先要将数据集转换为 tf.data。具有 to_tf_data 的数据集格式。在列中指定输入和标签,是否对数据集顺序、批量大小和数据排序器进行打乱:
>>> tf_train_set = tokenized_books["train"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=True,
batch_size=16,
collate_fn=data_collator,
)
>>> tf_test_set = tokenized_books["test"].to_tf_dataset(
columns=["attention_mask", "input_ids", "labels"],
shuffle=False,
batch_size=16,
collate_fn=data_collator,
)
如果您不熟悉使用 Kera 对模型进行微调,请在这里查看基本教程!
建立一个优化器函数、学习速率进度表和一些训练超参数:
>>> from transformers import create_optimizer, AdamWeightDecay
>>> optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
用 TFAutoModelForSeq2SeqLM 加载 T5:
>>> from transformers import TFAutoModelForSeq2SeqLM
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
通过编译配置培训模型:
>>> model.compile(optimizer=optimizer)
调用 fit 对模型进行微调:
>>> model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3)
有关如何调优翻译模型的更深入示例,请查看相应的 PyTorch 笔记本或 TensorFlow 笔记本。
边栏推荐
- ASP.Net Core实战——初识.NetCore
- How to create a new project with VS+Qt
- 浏览器指纹识别是什么意思?
- Talking about Shallow Cloning and Deep Cloning of ArraryList
- OpenCV简介与搭建使用环境
- CV复习:BatchNorm
- 将从后台获取到的数据 转换成 树形结构数据
- Server运维:设置.htaccess按IP和UA禁止访问
- strlen(), strcpy(), strncpy(), strcat(), strncat(), strcmp(), strncmp()函数的封装
- cheerio根据多个class匹配
猜你喜欢
随机推荐
VS2010:出现devenv.sln解决方案保存对话框
PAT1027 Printing Hourglass
Qt control - QTextEdit usage record
Use tensorboard remotely on the server
单向链表几个比较重要的函数(包括插入、删除、反转等)
深刻地认识到,编译器会导致编译结果的不同
爱因斯坦的光子理论
Suddenly want to analyze the mortgage interest rate and interest calculation
你知道亚马逊代运营的成本是多少吗?
stream去重相同属性对象
Matlab修改Consolas字体
小型项目如何使用异步任务管理器实现不同业务间的解耦
模仿微信金钱输入框规则(修复7.0手机崩溃)
(13)Filter过滤器
Talking about quantitative trading and programmatic trading
链表翻转 全翻转 部分翻转
鸡生蛋,蛋生鸡问题。JS顶级对象Function,Object关系
Simply record offsetof and container_of
什么是跨境电商测评?
WebGL探索——抉择:实践方向(twgl.js、Filament、Claygl、BabylonJS、ThreeJS、LayaboxJS、SceneJS、ThinkJS、ThingJS)