license: apache-2.0
tags:
- generated_from_trainer
metrics:
- wer
model-index:
- name: whisper-large-v2-spanish
  results: []
whisper-large-v2-spanish
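This model is a fine-tuned version of openai/whisper-large-v2. The fine-tuning dataset is not specified.
It achieves the following results on the evaluation set:
- Loss: 0.1466
- WER: 0.0855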
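WER (word error rate) is the word-level edit distance between a hypothesis and its reference (substitutions + deletions + insertions), divided by the number of reference words. A WER of 0.0855 therefore corresponds to roughly 8.55 word errors per 100 reference words. The pure-Python sketch below computes corpus-level WER, meaning total errors over total reference words, which is how `evaluate`'s `wer` metric (via `jiwer`) aggregates a list of sentences. The function names `word_edit_distance` and `corpus_wer` are illustrative and do not belong to any library.

```python
def word_edit_distance(ref: list[str], hyp: list[str]) -> int:
    """Minimum number of word substitutions, deletions and insertions turning ref into hyp."""
    # prev[j] = distance between the ref words seen so far (minus one) and the first j hyp words
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete the reference word
                curr[j - 1] + 1,         # insert the hypothesis word
                prev[j - 1] + (r != h),  # substitute (cost 0 when the words match)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references: list[str], predictions: list[str]) -> float:
    """Total word errors divided by total reference words across all sentence pairs."""
    errors = 0
    ref_words = 0
    for ref, hyp in zip(references, predictions):
        ref_tokens, hyp_tokens = ref.split(), hyp.split()
        errors += word_edit_distance(ref_tokens, hyp_tokens)
        ref_words += len(ref_tokens)
    return errors / ref_words


refs = ["el gato duerme en la casa", "hola mundo"]
hyps = ["el gato duerme en casa", "hola mundo"]
print(corpus_wer(refs, hyps))  # 1 deletion / 8 reference words = 0.125
```

Pooling errors over the whole corpus weights every reference word equally. Averaging per-sentence WERs instead would give short utterances disproportionate influence.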
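Model description
More information needed
Intended uses & limitations
More information needed
Training and evaluation data
More information needed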
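Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 1e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 500
- training_steps: 25000
- mixed_precision_training: Native AMP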
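With a linear scheduler and 500 warmup steps, the learning rate rises linearly from 0 to 1e-05 over the first 500 steps. It then decays linearly to 0 at step 25000. The function below is a pure-Python sketch of that schedule, written to follow the formula used by `transformers.get_linear_schedule_with_warmup` as we understand it. It is not the library code, and the name `linear_warmup_lr` is hypothetical.

```python
PEAK_LR = 1e-05       # learning_rate
WARMUP_STEPS = 500    # lr_scheduler_warmup_steps
TOTAL_STEPS = 25000   # training_steps


def linear_warmup_lr(step: int,
                     peak_lr: float = PEAK_LR,
                     warmup_steps: int = WARMUP_STEPS,
                     total_steps: int = TOTAL_STEPS) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from peak_lr at the end of warmup to 0 at total_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


for s in (0, 250, 500, 12750, 25000):
    print(s, linear_warmup_lr(s))
```

The peak is reached exactly at step 500. Because the warmup is short relative to the 25000 total steps, the model spends most of training in the decay phase.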
Training results

| Training Loss | Epoch | Step  | Validation Loss | WER    |
|:-------------:|:-----:|:-----:|:---------------:|:------:|
| 0.1908        | 0.03  | 1000  | 0.2235          | 0.1154 |
| 0.1888        | 0.07  | 2000  | 0.2132          | 0.1131 |
| 0.167         | 0.1   | 3000  | 0.2115          | 0.1133 |
| 0.1752        | 0.14  | 4000  | 0.2081          | 0.1146 |
| 0.1656        | 0.17  | 5000  | 0.2002          | 0.1073 |
| 0.1535        | 0.21  | 6000  | 0.1971          | 0.1086 |
| 0.1854        | 0.24  | 7000  | 0.1927          | 0.1048 |
| 0.1722        | 0.28  | 8000  | 0.1889          | 0.1043 |
| 0.166         | 0.31  | 9000  | 0.1850          | 0.1022 |
| 0.1277        | 0.35  | 10000 | 0.1820          | 0.1032 |
| 0.1457        | 0.38  | 11000 | 0.1777          | 0.0998 |
| 0.169         | 0.42  | 12000 | 0.1771          | 0.0982 |
| 0.1612        | 0.45  | 13000 | 0.1724          | 0.0976 |
| 0.1616        | 0.49  | 14000 | 0.1693          | 0.0956 |
| 0.1556        | 0.52  | 15000 | 0.1671          | 0.0942 |
| 0.1448        | 0.56  | 16000 | 0.1646          | 0.0930 |
| 0.117         | 0.59  | 17000 | 0.1613          | 0.0914 |
| 0.1441        | 0.62  | 18000 | 0.1596          | 0.0899 |
| 0.148         | 0.66  | 19000 | 0.1571          | 0.0895 |
| 0.1255        | 0.69  | 20000 | 0.1547          | 0.0874 |
| 0.1479        | 0.73  | 21000 | 0.1525          | 0.0885 |
| 0.1304        | 0.76  | 22000 | 0.1503          | 0.0861 |
| 0.1111        | 0.8   | 23000 | 0.1486          | 0.0867 |
| 0.1337        | 0.83  | 24000 | 0.1472          | 0.0854 |
| 0.1289        | 0.87  | 25000 | 0.1466          | 0.0855 |
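In the training-results table, validation loss falls at every checkpoint, but WER flattens out near the end. The lowest WER (0.0854) occurs at step 24000, one evaluation interval before the final checkpoint, whose loss and WER are the figures reported at the top of this card. The snippet below re-types the step, validation-loss and WER columns as Python tuples and picks the best checkpoint under each criterion. It is only a reading aid for the table and is not part of the training code.

```python
# (step, validation loss, WER), re-typed from the training-results table
rows = [
    (1000, 0.2235, 0.1154), (2000, 0.2132, 0.1131), (3000, 0.2115, 0.1133),
    (4000, 0.2081, 0.1146), (5000, 0.2002, 0.1073), (6000, 0.1971, 0.1086),
    (7000, 0.1927, 0.1048), (8000, 0.1889, 0.1043), (9000, 0.1850, 0.1022),
    (10000, 0.1820, 0.1032), (11000, 0.1777, 0.0998), (12000, 0.1771, 0.0982),
    (13000, 0.1724, 0.0976), (14000, 0.1693, 0.0956), (15000, 0.1671, 0.0942),
    (16000, 0.1646, 0.0930), (17000, 0.1613, 0.0914), (18000, 0.1596, 0.0899),
    (19000, 0.1571, 0.0895), (20000, 0.1547, 0.0874), (21000, 0.1525, 0.0885),
    (22000, 0.1503, 0.0861), (23000, 0.1486, 0.0867), (24000, 0.1472, 0.0854),
    (25000, 0.1466, 0.0855),
]

# Pick the checkpoint that minimizes each criterion
best_by_wer = min(rows, key=lambda r: r[2])
best_by_loss = min(rows, key=lambda r: r[1])
print("lowest WER:", best_by_wer)
print("lowest validation loss:", best_by_loss)
```

The two criteria disagree by a single checkpoint, and the WER gap between them is 0.0001, so either checkpoint is a reasonable choice.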
Transcription example:
```python
from datasets import load_dataset, Audio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish").to(device)
# Force Spanish transcription (not translation) in the decoder prompt.
# Written for Transformers 4.26; newer releases also accept generate(..., language="es", task="transcribe").
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")

# Stream the Spanish Common Voice 11 validation split (a gated dataset: accept its terms on the Hub first)
commonvoice_eval = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="validation", streaming=True)
# Resample to 16 kHz, the rate Whisper expects
commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=16000))
sample = next(iter(commonvoice_eval))["audio"]

# Convert the waveform to log-Mel input features, generate token ids, and decode them to text
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
```
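Whisper works on fixed 30-second windows. The processor pads or truncates 16 kHz audio to 30 × 16000 = 480000 samples and computes a log-Mel spectrogram with a hop length of 160 samples. That yields 3000 frames, and with the 80 Mel bins used by large-v2, `input_features` for one clip has shape (1, 80, 3000). The pure-Python sketch below reproduces only this pad-or-trim and frame-count arithmetic, not the spectrogram itself. The helper `pad_or_trim` borrows its name from the openai-whisper package but is written here from scratch.

```python
SAMPLING_RATE = 16000   # Hz, the rate the audio column is cast to
CHUNK_SECONDS = 30      # Whisper's fixed input window
HOP_LENGTH = 160        # samples between successive spectrogram frames
N_MELS = 80             # Mel bins used by whisper-large-v2

N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS   # 480000
N_FRAMES = N_SAMPLES // HOP_LENGTH          # 3000


def pad_or_trim(samples: list[float], length: int = N_SAMPLES) -> list[float]:
    """Zero-pad short clips and truncate long ones to exactly `length` samples."""
    if len(samples) >= length:
        return samples[:length]
    return samples + [0.0] * (length - len(samples))


five_seconds = [0.1] * (5 * SAMPLING_RATE)
forty_seconds = [0.1] * (40 * SAMPLING_RATE)
print(len(pad_or_trim(five_seconds)), len(pad_or_trim(forty_seconds)))
print("expected input_features shape:", (1, N_MELS, N_FRAMES))
```

Because anything past 30 seconds is dropped, the single-call approach in the transcription example only suits short clips such as Common Voice utterances. Longer recordings need chunking, for example via the `chunk_length_s` option of the `transformers` automatic-speech-recognition pipeline.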
Evaluation:
The model was evaluated on the test split of mozilla-foundation/common_voice_11_0.
```python
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wer_metric = evaluate.load("wer")
# Language-agnostic normalizer applied to both references and predictions
whisper_norm = BasicTextNormalizer()

processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
# Move the model to the device once, instead of on every map call
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish").to(device)
# Force Spanish transcription in the decoder prompt
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")

dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test")
# Resample to 16 kHz, the rate Whisper expects
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

def normalize(batch):
    # Normalize the reference transcript
    batch["gold_text"] = whisper_norm(batch["sentence"])
    return batch

def map_wer(batch):
    # Transcribe one example and normalize the hypothesis the same way as the reference
    inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
    with torch.no_grad():
        generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    batch["predicted_text"] = whisper_norm(transcription)
    return batch

processed_dataset = dataset.map(normalize)
predicted = processed_dataset.map(map_wer)
wer = wer_metric.compute(references=predicted["gold_text"], predictions=predicted["predicted_text"])
wer = round(100 * wer, 2)  # report as a percentage
print("WER:", wer)
```
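The evaluation script runs both references and predictions through the same `BasicTextNormalizer` before scoring, so differences in case and punctuation are not counted as word errors. The sketch below uses a deliberately simplified, hypothetical stand-in, `simple_normalize`, which lowercases, turns Unicode punctuation into spaces and collapses whitespace. It is not the transformers implementation, which also strips bracketed text and other symbol categories. The example shows the effect: scored raw, the pair below counts 2 substitutions in 4 reference words (WER 0.5), while after normalization the two strings match.

```python
import re
import unicodedata


def simple_normalize(text: str) -> str:
    """Lowercase, replace punctuation with spaces, and collapse runs of whitespace.

    Hypothetical simplification for illustration. Accented letters such as
    'á' or 'ñ' are kept because they are letters, not punctuation.
    """
    text = text.lower()
    # Unicode categories starting with "P" are punctuation (this includes ¿ and ¡)
    text = "".join(" " if unicodedata.category(ch).startswith("P") else ch for ch in text)
    return re.sub(r"\s+", " ", text).strip()


reference = "¿Dónde está la estación?"
prediction = "dónde está la estación"
print(reference.split() == prediction.split())                                      # False: case and punctuation differ
print(simple_normalize(reference).split() == simple_normalize(prediction).split())  # True after normalization
```

Keeping diacritics matters for Spanish, because pairs such as "esta" and "está" are different words. `BasicTextNormalizer` likewise leaves diacritics in place by default.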
Framework versions
- Transformers 4.26.0.dev0
- PyTorch 1.13.1+cu117
- Datasets 2.8.1.dev0
- Tokenizers 0.13.2