license: mit
datasets:
- mozilla-foundation/common_voice_16_1
language:
- es
library_name: transformers
pipeline_tag: automatic-speech-recognition
tags:
- spanish
- español
- speech
- recognition
- whisper
- distil-whisper
distil-whisper-large-v3-es
This repository contains a distilled version of the Whisper large-v3 model, trained on the Mozilla Common Voice dataset v16.1.
The model was developed by SandboxAI in collaboration with the Universidad Nacional de Rio Negro.
Usage
Distil-Whisper is supported in Hugging Face 🤗 Transformers from version 4.35 onwards. To run the model, first install the latest version of the Transformers library. For this example, we will also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
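If you are unsure which version is installed, a quick check (a convenience step, not part of the original instructions) is:

import transformers

# Distil-Whisper requires Transformers 4.35 or newer.
print(transformers.__version__)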
Short-Form Transcription
The model can be used with the pipeline class to transcribe short-form audio files (< 30 seconds) as follows:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "marianbasti/distil-whisper-large-v3-es"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
To transcribe a local audio file, simply pass the path to the audio file when calling the pipeline:
- result = pipe(sample)
+ result = pipe("audio.mp3")
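Putting the pieces together, a minimal self-contained sketch for a local file might look as follows. Here the pipeline is built directly from the checkpoint id rather than from a separately loaded model and processor (standard Transformers behaviour), and audio.mp3 is just a placeholder path:

import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Build the ASR pipeline straight from the Hub checkpoint.
pipe = pipeline(
    "automatic-speech-recognition",
    model="marianbasti/distil-whisper-large-v3-es",
    torch_dtype=torch_dtype,
    device=device,
    max_new_tokens=128,
)

print(pipe("audio.mp3")["text"])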
Long-Form Transcription
Distil-Whisper uses a chunked algorithm to transcribe long-form audio files (> 30 seconds). In practice, this chunked algorithm is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper (see Table 7 of the Distil-Whisper paper).
To enable chunking, pass the chunk_length_s parameter to the pipeline; for Distil-Whisper, a chunk length of 15 seconds is optimal. To activate batching, pass the batch_size argument:
pipe = pipeline(
    ...,
    chunk_length_s=15,
    batch_size=16,
)
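For reference, a complete long-form call, reusing the model and processor loaded in the short-form example above, might look like this (long_audio.mp3 is a placeholder file name):

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,   # split the audio into 15-second windows
    batch_size=16,       # transcribe the chunks in batches of 16
    torch_dtype=torch_dtype,
    device=device,
)

result = pipe("long_audio.mp3")
print(result["text"])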
Speculative Decoding
Distil-Whisper can be used as an assistant model to Whisper for speculative decoding, yielding mathematically identical outputs while being 2x faster. The following example shows how to load the assistant model:
from transformers import AutoModelForCausalLM

assistant_model = AutoModelForCausalLM.from_pretrained(
    "marianbasti/distil-whisper-large-v3-es",
    torch_dtype=torch_dtype,
)
pipe = pipeline(
    ...,
    generate_kwargs={"assistant_model": assistant_model},
)
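For completeness, a full speculative-decoding setup, following the same loading pattern as above, might look as follows. The main model here is assumed to be openai/whisper-large-v3 (the teacher this checkpoint was distilled from); the distilled Spanish model acts as the draft/assistant model:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Assistant (draft) model: the distilled Spanish checkpoint.
assistant_model = AutoModelForCausalLM.from_pretrained(
    "marianbasti/distil-whisper-large-v3-es",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
assistant_model.to(device)

# Main model: assumed to be the original Whisper large-v3 teacher.
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

print(pipe("audio.mp3")["text"])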
Training
The model was trained for roughly 60 hours on a single RTX 3090 (60,000 steps, about 1.47 epochs) with the following key training parameters:
--learning_rate 1e-4
--per_device_train_batch_size 8
--max_steps 60000
--gradient_checkpointing
--freeze_encoder
Results
The model achieves a WER of 5.11% (10.15% orthographic WER).
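As a rough illustration of how such a figure is computed (this evaluation snippet is not from the original card and requires the evaluate and jiwer packages), the WER between predictions and references can be obtained with the 🤗 Evaluate library:

from evaluate import load

wer_metric = load("wer")

# Hypothetical example: model outputs vs. reference transcriptions.
predictions = ["hola a todos", "buenos días"]
references = ["hola a todos", "buenos días a todos"]

# compute() returns a fraction; multiply by 100 for a percentage.
print(100 * wer_metric.compute(predictions=predictions, references=references))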
License
Distil-Whisper inherits the MIT license from OpenAI's Whisper model.
Citation
If you use this model, please consider citing the Distil-Whisper paper:
@misc{gandhi2023distilwhisper,
      title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
      author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
      year={2023},
      eprint={2311.00430},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}