用户提供的翻译信息如下:
语言:
- 英语
许可证:MIT
库名称:transformers
标签:
- 音频
- 自动语音识别
小部件:
- 示例标题:LibriSpeech样本1
来源:https://cdn-media.huggingface.co/speech_samples/sample1.flac
- 示例标题:LibriSpeech样本2
来源:https://cdn-media.huggingface.co/speech_samples/sample2.flac
管道标签:自动语音识别
Distil-Whisper: Distil-Large-v3.5
Distil-Whisper是OpenAI Whisper-Large-v3 的知识蒸馏版本,相关技术细节在论文《通过大规模伪标签实现稳健知识蒸馏》中描述。作为Distil-Whisper英语系列的最新成员,Distil-Large-v3.5在保持前代模型高效性的同时,提供了更优的性能表现。
相比早期模型,Distil-Large-v3.5的训练数据量增加了4倍以上(达98,000小时),并采用了"耐心"教师模型策略,在蒸馏过程中延长训练周期并应用了强力的数据增强技术(SpecAugment)。这使得模型在鲁棒性和准确性方面都优于之前的Distil-Whisper版本,可作为直接替代方案使用。
为何在已有Whisper-Large-v3-Turbo的情况下还要考虑Distil-Large-v3.5?
- 它在准确性和效率之间提供了不同的平衡点,在保持比Whisper-Large-v3-Turbo快约1.5倍的同时,短格式转录性能略优,长格式转录仅落后约1%。
- 它非常适合作为Whisper-Large-v3推测解码的草稿模型。通过在训练期间冻结编码器,我们只需加载两个额外的解码器层并前向传播编码器一次。相比Whisper-Large-v3,这能实现约2倍的推理加速,同时保持输出完全一致。
本模型是Bofeng Huang、Eustache Le Bihan、Steven Zheng和Vaibhav Srivastav在🤗平台上的协作成果。
目录
性能表现
模型在短格式和长格式转录任务上进行了评估,使用领域内(ID)和领域外(OOD)数据集来测试准确性、泛化能力和鲁棒性。
注意:此处展示的词错误率(WER)结果是经过后规范化处理的,包括转换为小写、移除符号和标点等操作。
短格式评估
我们按照🤗 Open ASR排行榜的方法,在5个领域内(ID)测试集和2个领域外(OOD)测试集上评估了模型的短格式转录性能。
数据集 |
规模/小时 |
large-v3 |
large-v3-turbo |
distil-v3 |
distil-v3.5 |
AMI |
8.68 |
15.95 |
16.13 |
15.16 |
14.63 |
Gigaspeech |
35.36 |
10.02 |
10.14 |
10.08 |
9.84 |
LS Clean |
5.40 |
2.01 |
2.10 |
2.54 |
2.37 |
LS Other |
5.34 |
3.91 |
4.24 |
5.19 |
5.04 |
Tedlium |
2.61 |
3.86 |
3.57 |
3.86 |
3.64 |
----------- |
----- |
----- |
----- |
----- |
----- |
Earnings22 |
5.43 |
11.29 |
11.63 |
11.79 |
11.29 |
SPGISpeech |
100.00 |
2.94 |
2.97 |
3.27 |
2.87 |
----------- |
----- |
----- |
----- |
----- |
----- |
ID平均值 |
|
7.15 |
7.24 |
7.37 |
7.10 |
OOD平均值 |
|
7.12 |
7.30 |
7.53 |
7.08 |
平均值 |
|
7.14 |
7.25 |
7.41 |
7.10 |
注:ID/OOD分类基于distil-v3和distil-v3.5的训练数据。由于Large-v3和large-v3-turbo的训练语料详情未知,此分类可能无法准确反映它们真实的领域内外表现。
长格式评估
我们使用顺序解码算法(condition_on_prev_tokens=False, return_timestamps=True),在1个领域内(ID)测试集和4个领域外(OOD)测试集上评估了模型的长格式转录性能。
数据集 |
规模/小时 |
large-v3-turbo |
distil-v2 |
distil-v3 |
distil-v3.5 |
tedlium-long-form |
2.47 |
3.07 |
9.66 |
3.9 |
4.63 |
----------------- |
----- |
----- |
----- |
----- |
----- |
meanwhile |
1.01 |
5.03 |
16.75 |
7.04 |
6.79 |
earnings21 |
39.26 |
9.84 |
15.09 |
10.54 |
10.6 |
earnings22 |
119.89 |
13.32 |
19.11 |
15.06 |
14.19 |
rev16 |
16.16 |
12.82 |
21.15 |
13.76 |
13.98 |
----------------- |
----- |
----- |
----- |
----- |
----- |
ID平均值 |
|
3.07 |
9.66 |
3.9 |
4.63 |
OOD平均值 |
|
10.25 |
18.03 |
11.6 |
11.39 |
平均值 |
|
8.82 |
16.35 |
10.06 |
10.04 |
注:ID/OOD分类基于distil-v3和distil-v3.5的训练数据。由于Large-v3和large-v3-turbo的训练语料详情未知,此分类可能无法准确反映它们真实的领域内外表现。
下方的实时因子(RTFx)测量结果显示,在长格式转录任务上,Distil-Large-v3.5比Whisper-Large-v3-Turbo快约1.5倍。
数据集 |
large-v3-turbo |
distil-v2 |
distil-v3 |
distil-v3.5 |
tedlium-long-form |
34.33 |
27.96 |
44.95 |
45.19 |
meanwhile |
26.55 |
28.01 |
40.84 |
42.48 |
earnings21 |
35.25 |
36.66 |
54.69 |
54.3 |
earnings22 |
39.08 |
42.09 |
57.28 |
58.8 |
rev16 |
33.86 |
23.87 |
45.43 |
45.91 |
----------------- |
----- |
----- |
----- |
----- |
平均值 |
33.81 |
31.72 |
48.64 |
49.34 |
Transformers使用指南
Distil-Large-v3.5从4.39版本开始支持Hugging Face 🤗 Transformers库。要运行模型,首先安装最新版本的Transformers。本示例中,我们还将安装🤗 Datasets以从Hugging Face Hub加载示例音频数据集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
短格式转录
可以使用pipeline
类转录短格式音频文件(<30秒):
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要转录本地音频文件,只需在调用pipeline时传入音频文件路径:
- result = pipe(sample)
+ result = pipe("audio.mp3")
要获取分段级时间戳,传递参数return_timestamps=True
并返回"chunks"
输出:
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
如需更精细控制生成参数,可直接使用模型+处理器API:
临时生成参数可传递给model.generate
,包括用于束搜索的num_beams
、用于分段级时间戳的return_timestamps
以及用于提示的prompt_ids
。详见文档字符串。
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset
device =