本章简介
- 如何从模型中心(hub)加载大型数据集
- 如何使用高级的
TrainerAPI 微调一个模型 - 如何使用自定义训练过程
- 如何利用🤗 Accelerate 库在所有分布式设备上轻松运行自定义训练过程
预处理数据
处理数据
从模型中心(hub)加载数据集
模型中心(hub)不仅仅包含模型,还有许多别的语言的数据集。访问 Datasets 的链接即可进行浏览。我们建议你在完成本节的学习后阅读一下 加载和处理新的数据集 这篇文章,这会让你对 huggingface 的数据集理解更加清晰。 现在让我们使用 MRPC 数据集中的 GLUE 基准测试数据集 作为我们训练所使用的数据集,它是构成 MRPC 数据集的 10 个数据集之一,作为一个用于衡量机器学习模型在 10 个不同文本分类任务中性能的学术基准。
from datasets import load_dataset
raw_datasets = load_dataset("glue", "mrpc")
raw_datasetsDatasetDict({
train: Dataset({
features: ['sentence1', 'sentence2', 'label', 'idx'],
num_rows: 3668
})
validation: Dataset({
features: ['sentence1', 'sentence2', 'label', 'idx'],
num_rows: 408
})
test: Dataset({
features: ['sentence1', 'sentence2', 'label', 'idx'],
num_rows: 1725
})
})预处理数据集
def tokenize_function(example):
return tokenizer(example["sentence1"], example["sentence2"], truncation=True)tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasetsDatasetDict({
train: Dataset({
features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],
num_rows: 3668
})
validation: Dataset({
features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],
num_rows: 408
})
test: Dataset({
features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],
num_rows: 1725
})
})