当前位置:网站首页>抱抱脸(hugging face)教程-中文翻译-文本分类
抱抱脸(hugging face)教程-中文翻译-文本分类
2022-08-09 14:56:00 【wwlsm_zql】
文本分类
文本分类是一个常见的自然语言处理任务,它为文本分配一个标签或类。有许多实际应用的文本分类广泛应用于生产的一些今天的最大的公司。最流行的文本分类形式之一是情感分析,它为一系列文本分配一个标签,如正面、负面或中性。
本指南将向您展示如何对 IMDb 数据集上的 DistilBERT 进行微调,以确定电影评论是正面的还是负面的。
有关其他形式的文本分类及其相关模型、数据集和度量的更多信息,请参见文本分类任务页。
从 Datasets 库加载 IMDb 数据集:
>>> from datasets import load_dataset
>>> imdb = load_dataset("imdb")
然后看一个例子:
>>> imdb["test"][0]
{
"label": 0,
"text": "I love sci-fi and ...Jeeez! Dallas all over again.",
}
这个数据集中有两个字段:
- Text: 包含电影评论文本的字符串。
- Label: 一个值,对于负面评价可以是0,对于正面评价可以是1。
预处理
加载 DistilBERT 标记器以处理文本字段:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
创建一个预处理函数来标记文本并截断不超过 DistilBERT 最大输入长度的序列:
>>> def preprocess_function(examples):
return tokenizer(examples["text"], truncation=True)
使用 Datasets 映射函数对整个数据集应用预处理函数。您可以通过设置 batching=True 来加速 map 函数,以便同时处理数据集中的多个元素:
>>> tokenized_imdb = imdb.map(preprocess_function, batched=True)
使用 DataCollatorWithPadd 创建一批示例。它还将动态地将文本填充到其批处理中最长元素的长度,因此它们是统一长度。虽然可以通过设置 pding = True 在 tokenizer 函数中填充文本,但是动态填充效率更高。
Pytorch
>>> from transformers import DataCollatorWithPadding
>>> data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
TensorFlow
>>> from transformers import DataCollatorWithPadding
>>> data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
Train
Pytorch
使用 AutoModelForSequenceAnalysis 加载 DistilBERT 以及预期的标签数量:
>>>from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
>>>model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
如果您不熟悉用 Trainer 对模型进行微调,请参阅这里的基本教程!
目前,只剩下三个步骤:
在 TrainingArguments 中定义训练超参数。
将训练参数连同模型、数据集、标记器和数据校对器一起传递给 Trainer。
调用 train ()对模型进行微调。
>>> training_args = TrainingArguments(
output_dir="./results",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=5,
weight_decay=0.01,
)
>>> trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_imdb["train"],
eval_dataset=tokenized_imdb["test"],
tokenizer=tokenizer,
data_collator=data_collator,
)
>>> trainer.train()
当您向 Trainer 传递 tokenizer 时,它将默认应用动态填充。在这种情况下,不需要显式地指定数据校对器。
TensorFlow
要在 TensorFlow 中微调模型,首先要将数据集转换为 tf.data。具有 to_tf_data 的数据集格式。在列中指定输入和标签,是否对数据集顺序、批量大小和数据排序器进行洗牌:
>>> tf_train_set = tokenized_imdb["train"].to_tf_dataset(
columns=["attention_mask", "input_ids", "label"],
shuffle=True,
batch_size=16,
collate_fn=data_collator,
)
>>> tf_validation_set = tokenized_imdb["test"].to_tf_dataset(
columns=["attention_mask", "input_ids", "label"],
shuffle=False,
batch_size=16,
collate_fn=data_collator,
)
如果您不熟悉使用 Kera 对模型进行微调,请在这里查看基本教程!
建立一个优化器函数、学习速率进度表和一些训练超参数:
>>> from transformers import create_optimizer
>>> import tensorflow as tf
>>> batch_size = 16
>>> num_epochs = 5
>>> batches_per_epoch = len(tokenized_imdb["train"]) // batch_size
>>> total_train_steps = int(batches_per_epoch * num_epochs)
>>> optimizer, schedule = create_optimizer(init_lr=2e-5, num_warmup_steps=0, num_train_steps=total_train_steps)
使用 TFAutoModelForSequenceAnalysis 加载 DistilBERT 以及预期的标签数量:
>>> from transformers import TFAutoModelForSequenceClassification
>>> model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
通过编译配置培训模型:
>>> import tensorflow as tf
>>> model.compile(optimizer=optimizer)
调用 fit 对模型进行微调:
>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=3)
有关如何微调文本分类模型的更深入示例,请查看相应的 PyTorch 笔记本或 TensorFlow 笔记本。
边栏推荐
猜你喜欢
随机推荐
pyspark.sql之实现collect_list的排序
什么是跨境电商测评?
WebGL:BabylonJS入门——初探:我的世界
《身体是革命的本钱,该注意时还是要注意!》
Basic principles and common methods of digital image processing
pyspark dataframe分位数计算
Several important functional operations of general two-way circular list
spark shuffle
升职加薪之SQL索引
The recycle bin has been showed no problem to empty the icon
解决跨域问题的三种方式
It is deeply recognized that the compiler can cause differences in the compilation results
LNK1123: Failed during transition to COFF: invalid or corrupt file
pyspark jieba 集群模型 对文本进行切词
AsyncTask 串行还是并行
Qt control - QTextEdit usage record
js总结,基础篇
.Net Core动态注入
A shortcut method for writing menu commands in C
PAT1027 打印沙漏