---
license: apache-2.0
tags:
- generated_from_trainer
metrics:
- wer
model-index:
- name: whisper-small-spanish
  results: []
---
# whisper-small-sp

This model is a fine-tuned version of openai/whisper-small on the Common Voice 11.0 dataset. It achieves the following results on the evaluation set:
- Loss: 0.4485
- Wer: 20.6842
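
For quick experimentation, the checkpoint can also be loaded through the Transformers `pipeline` API. The snippet below is a minimal sketch under assumptions not stated in this card: it uses the `automatic-speech-recognition` pipeline task, and `audio.wav` is a placeholder path for any local recording.

```python
import torch
from transformers import pipeline

# Minimal sketch: load the fine-tuned checkpoint as an ASR pipeline.
# "audio.wav" is a placeholder path, not a file shipped with this card.
asr = pipeline(
    "automatic-speech-recognition",
    model="clu-ling/whisper-small-spanish",
    device=0 if torch.cuda.is_available() else -1,
)

print(asr("audio.wav")["text"])
```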
## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed
## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (an approximate configuration sketch follows the list):
- learning_rate: 0.0005
- train_batch_size: 16
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 500
- training_steps: 25000
- mixed_precision_training: Native AMP
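
For reference, these settings roughly correspond to the `Seq2SeqTrainingArguments` below. This is a hedged reconstruction rather than the actual training script: `output_dir` is a placeholder and the batch sizes are assumed to be per device.

```python
from transformers import Seq2SeqTrainingArguments

# Approximate reconstruction of the listed hyperparameters.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-spanish",  # placeholder, not from the card
    learning_rate=5e-4,
    per_device_train_batch_size=16,        # assumed to be per device
    per_device_eval_batch_size=8,
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    lr_scheduler_type="linear",
    warmup_steps=500,
    max_steps=25000,
    fp16=True,                             # "Native AMP" mixed precision
)
```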
### Training results

| Training Loss | Epoch | Step  | Validation Loss | Wer     |
|:-------------:|:-----:|:-----:|:---------------:|:-------:|
| 2.2671        | 0.13  | 1000  | 2.2108          | 76.2667 |
| 1.4465        | 0.26  | 2000  | 1.6057          | 67.8753 |
| 1.0997        | 0.39  | 3000  | 1.1928          | 54.2433 |
| 0.9389        | 0.52  | 4000  | 1.0020          | 47.8307 |
| 0.7881        | 0.65  | 5000  | 0.8933          | 46.0046 |
| 0.7596        | 0.78  | 6000  | 0.7721          | 38.5595 |
| 0.5678        | 0.91  | 7000  | 0.6903          | 36.2897 |
| 0.4412        | 1.04  | 8000  | 0.6476          | 32.7473 |
| 0.4239        | 1.17  | 9000  | 0.5973          | 30.8142 |
| 0.3935        | 1.3   | 10000 | 0.5444          | 29.0208 |
| 0.3307        | 1.43  | 11000 | 0.5024          | 27.0434 |
| 0.2937        | 1.56  | 12000 | 0.4608          | 24.7318 |
| 0.2471        | 1.69  | 13000 | 0.4259          | 22.8940 |
| 0.2357        | 1.82  | 14000 | 0.3936          | 21.6018 |
| 0.2292        | 1.95  | 15000 | 0.3776          | 20.8004 |
| 0.1493        | 2.08  | 16000 | 0.4599          | 24.0491 |
| 0.1708        | 2.21  | 17000 | 0.4370          | 23.3443 |
| 0.1385        | 2.34  | 18000 | 0.4277          | 22.3171 |
| 0.1288        | 2.47  | 19000 | 0.4050          | 21.0118 |
| 0.1627        | 2.6   | 20000 | 0.4507          | 23.4004 |
| 0.1675        | 2.73  | 21000 | 0.4346          | 22.8261 |
| 0.159         | 2.86  | 22000 | 0.4179          | 22.2949 |
| 0.1458        | 2.99  | 23000 | 0.3978          | 21.0810 |
| 0.0487        | 3.12  | 24000 | 0.4456          | 20.8617 |
| 0.0401        | 3.25  | 25000 | 0.4485          | 20.6842 |
## Transcription example

```python
from datasets import load_dataset, Audio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor and the fine-tuned model
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-small-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-small-spanish").to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")

# Stream one sample from the Common Voice 11.0 Spanish validation split, resampled to 16 kHz
commonvoice_eval = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="validation", streaming=True)
commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=16000))
sample = next(iter(commonvoice_eval))["audio"]

# Extract log-Mel features and generate the transcription
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
```
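
The same processor/model pair can transcribe a local recording instead of a streamed Common Voice sample. The sketch below is a hedged variant: `librosa` is not used elsewhere in this card, and `sample.wav` is a placeholder path.

```python
import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-small-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-small-spanish").to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")

# "sample.wav" is a placeholder path; librosa resamples the audio to 16 kHz,
# the sampling rate Whisper's feature extractor expects.
speech, sr = librosa.load("sample.wav", sr=16000)

input_features = processor(speech, sampling_rate=sr, return_tensors="pt").input_features
predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```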
## Evaluation

The following script evaluates the model on the test split of mozilla-foundation/common_voice_11_0:
```python
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# WER metric and the text normalizer applied to both references and predictions
wer_metric = evaluate.load("wer")
whisper_norm = BasicTextNormalizer()

processor = WhisperProcessor.from_pretrained("clu-ling/whisper-small-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-small-spanish").to(device)

dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

def normalize(batch):
    # Normalize the reference transcript
    batch["gold_text"] = whisper_norm(batch["sentence"])
    return batch

def map_wer(batch):
    # Transcribe one example and normalize the prediction
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")
    inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
    with torch.no_grad():
        generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    batch["predicted_text"] = whisper_norm(transcription)
    return batch

processed_dataset = dataset.map(normalize)
predicted = processed_dataset.map(map_wer)

wer = wer_metric.compute(references=predicted["gold_text"], predictions=predicted["predicted_text"])
wer = round(100 * wer, 2)
print("WER:", wer)
```
### Framework versions

- Transformers 4.25.1
- PyTorch 1.13.0+cu117
- Datasets 2.7.1
- Tokenizers 0.13.2