🚀 粤语版小模型Whisper - Alvin
本模型是基于粤语对 openai/whisper-small 进行微调的版本。在 Common Voice 16.0 数据集上,其字符错误率(CER)在无标点时为 7.93%,有标点时为 9.72%。
✨ 主要特性
- 基于预训练模型
openai/whisper-small
进行粤语微调。
- 在多个粤语数据集上进行训练和评估,具有较好的粤语语音识别性能。
- 支持多种推理加速方法,如 Flash Attention 和 Speculative Decoding。
📦 安装指南
文档未提及安装步骤,暂不提供。
💻 使用示例
基础用法
import librosa
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
y, sr = librosa.load('audio.mp3', sr=16000)
MODEL_NAME = "alvanlii/whisper-small-cantonese"
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
processed_in = processor(y, sampling_rate=sr, return_tensors="pt")
gout = model.generate(
input_features=processed_in.input_features,
output_scores=True, return_dict_in_generate=True
)
transcription = processor.batch_decode(gout.sequences, skip_special_tokens=True)[0]
print(transcription)
高级用法
使用 huggingface pipelines 进行推理:
from transformers import pipeline
MODEL_NAME = "alvanlii/whisper-small-cantonese"
lang = "zh"
device = 0
pipe = pipeline(
task="automatic-speech-recognition",
model=MODEL_NAME,
chunk_length_s=30,
device=device,
)
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
text = pipe('audio.mp3')["text"]
📚 详细文档
训练和评估数据
训练数据
- CantoMap:Winterstein, Grégoire, Tang, Carmen 和 Lai, Regine (2020) "CantoMap: a Hong Kong Cantonese MapTask Corpus",发表于 The 12th Language Resources and Evaluation Conference 会议论文集,Marseille: European Language Resources Association, p. 2899 - 2906。
- Cantonse - ASR:Yu, Tiezheng, Frieske, Rita, Xu, Peng, Cahyawijaya, Samuel, Yiu, Cheuk Tung, Lovenia, Holy, Dai, Wenliang, Barezi, Elham, Chen, Qifeng, Ma, Xiaojuan, Shi, Bertram, Fung, Pascale (2022) "Automatic Speech Recognition Datasets in Cantonese: A Survey and New Dataset",2022 年。链接:https://arxiv.org/pdf/2201.02419.pdf
名称 |
时长(小时) |
Common Voice 16.0 zh - HK Train |
138 |
Common Voice 16.0 yue Train |
85 |
Common Voice 17.0 yue Train |
178 |
Cantonese - ASR |
72 |
CantoMap |
23 |
Pseudo - Labelled YouTube Data |
438 |
评估数据
使用 Common Voice 16.0 yue 测试集进行评估。
评估结果
- 字符错误率(CER,越低越好):
- 无标点:0.0793
- 有标点:0.0972,较之前版本的 0.1073 和 0.1581 有所下降
- GPU 推理(使用 Fast Attention,示例如下):每个样本 0.055 秒
- 注意:所有 GPU 评估均在 RTX 3090 GPU 上进行
- GPU 推理:每个样本 0.308 秒
- CPU 推理:每个样本 2.57 秒
- GPU 显存占用:约 1.5 GB
模型加速
只需添加 attn_implementation="sdpa"
即可使用 Flash Attention 进行加速。
from transformers import AutoModelForSpeechSeq2Seq
import torch
torch_dtype = torch.float16
model = AutoModelForSpeechSeq2Seq.from_pretrained(
"alvanlii/whisper-small-cantonese",
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
使用 Flash Attention 后,每个样本的推理时间从 0.308 秒减少到 0.055 秒。
推测解码
可以使用更大的模型,然后使用 alvanlii/whisper-small-cantonese
加速推理,且基本不损失准确性。
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "simonl0909/whisper-large-v2-cantonese"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
assistant_model_id = "alvanlii/whisper-small-cantonese"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device)
inputs = processor(...)
model.generate(**inputs, use_cache=True, assistant_model=assistant_model)
原始的 simonl0909/whisper-large-v2-cantonese
模型每个样本推理时间为 0.714 秒,CER 为 7.65%。使用 alvanlii/whisper-small-cantonese
进行推测解码后,每个样本推理时间为 0.137 秒,CER 为 7.67%,速度大幅提升。
Whisper.cpp
截至 2024 年 6 月,已上传用于 Whisper cpp 的 GGML 二进制文件。可以从 这里 下载二进制文件,并在 这里 进行测试。
Whisper CT2
若要在 WhisperX 或 FasterWhisper 中使用,需要 CT2 文件。转换后的模型文件位于 这里。
训练超参数
属性 |
详情 |
学习率 |
5e - 5 |
训练批次大小 |
25(在 1 块 3090 GPU 上) |
评估批次大小 |
8 |
梯度累积步数 |
4 |
总训练批次大小 |
25 x 4 = 100 |
优化器 |
Adam,beta=(0.9, 0.999),epsilon = 1e - 08 |
学习率调度器类型 |
线性 |
学习率调度器热身步数 |
500 |
训练步数 |
15000 |
数据增强 |
无 |
🔧 技术细节
文档未提供足够详细的技术实现细节,暂不提供。
📄 许可证
本模型遵循 Apache - 2.0 许可证。