语言:
- 捷克语
许可证: Apache-2.0
标签:
- 自动语音识别
- 训练生成
- hf-asr排行榜
- mozilla-foundation/common_voice_8_0
- 鲁棒语音事件
- xlsr微调周
数据集:
- mozilla-foundation/common_voice_8_0
- ovm
- pscr
- vystadial2016
基础模型: facebook/wav2vec2-xls-r-300m
模型索引:
- 名称: 捷克语comodoro Wav2Vec2 XLSR 300M 250小时数据
结果:
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: Common Voice 8
类型: mozilla-foundation/common_voice_8_0
参数: cs
指标:
- 类型: 词错误率(WER)
值: 7.3
名称: 测试WER
- 类型: 字符错误率(CER)
值: 2.1
名称: 测试CER
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: 鲁棒语音事件 - 开发数据
类型: speech-recognition-community-v2/dev_data
参数: cs
指标:
- 类型: 词错误率(WER)
值: 43.44
名称: 测试WER
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: 鲁棒语音事件 - 测试数据
类型: speech-recognition-community-v2/eval_data
参数: cs
指标:
- 类型: 词错误率(WER)
值: 38.5
名称: 测试WER
捷克语wav2vec2-xls-r-300m-cs-250
该模型是基于facebook/wav2vec2-xls-r-300m在common_voice 8.0数据集以及其他列出的数据集上进行微调的版本。
在评估集上达到以下结果:
- 损失: 0.1271
- 词错误率(WER): 0.1475
- 字符错误率(CER): 0.0329
使用语言模型的eval.py
脚本结果为:
- WER: 0.07274312090176113
- CER: 0.021207369275558875
模型描述
基于facebook/wav2vec2-large-xlsr-53在捷克语Common Voice数据集上微调。
使用此模型时,请确保语音输入采样率为16kHz。
该模型可直接使用(无需语言模型),如下所示:
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
test_dataset = load_dataset("mozilla-foundation/common_voice_8_0", "cs", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("comodoro/wav2vec2-xls-r-300m-cs-250")
model = Wav2Vec2ForCTC.from_pretrained("comodoro/wav2vec2-xls-r-300m-cs-250")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("预测:", processor.batch_decode(predicted_ids))
print("参考:", test_dataset[:2]["sentence"])
评估
可使用附带的eval.py
脚本评估模型:
python eval.py --model_id comodoro/wav2vec2-xls-r-300m-cs-250 --dataset mozilla-foundation/common-voice_8_0 --split test --config cs
训练和评估数据
训练使用了Common Voice 8.0的train
和validation
数据集,以及以下数据集:
-
Šmídl, Luboš and Pražák, Aleš, 2013, OVM – Otázky Václava Moravce, LINDAT/CLARIAH-CZ数字图书馆,位于查理大学数学与物理学院形式与应用语言学研究所(ÚFAL),http://hdl.handle.net/11858/00-097C-0000-000D-EC98-3.
-
Pražák, Aleš and Šmídl, Luboš, 2012, 捷克议会会议记录, LINDAT/CLARIAH-CZ数字图书馆,位于查理大学数学与物理学院形式与应用语言学研究所(ÚFAL),http://hdl.handle.net/11858/00-097C-0000-0005-CF9C-4.
-
Plátek, Ondřej; Dušek, Ondřej and Jurčíček, Filip, 2016, Vystadial 2016 – 捷克数据, LINDAT/CLARIAH-CZ数字图书馆,位于查理大学数学与物理学院形式与应用语言学研究所(ÚFAL),http://hdl.handle.net/11234/1-1740.
训练超参数
训练期间使用的超参数如下:
- 学习率: 0.0001
- 训练批次大小: 32
- 评估批次大小: 8
- 随机种子: 42
- 优化器: Adam,参数beta=(0.9,0.999),epsilon=1e-08
- 学习率调度器类型: 线性
- 学习率预热步数: 800
- 训练轮数: 5
- 混合精度训练: Native AMP
训练结果
训练损失 |
轮次 |
步数 |
验证损失 |
WER |
CER |
3.4203 |
0.16 |
800 |
3.3148 |
1.0 |
1.0 |
2.8151 |
0.32 |
1600 |
0.8508 |
0.8938 |
0.2345 |
0.9411 |
0.48 |
2400 |
0.3335 |
0.3723 |
0.0847 |
0.7408 |
0.64 |
3200 |
0.2573 |
0.2840 |
0.0642 |
0.6516 |
0.8 |
4000 |
0.2365 |
0.2581 |
0.0595 |
0.6242 |
0.96 |
4800 |
0.2039 |
0.2433 |
0.0541 |
0.5754 |
1.12 |
5600 |
0.1832 |
0.2156 |
0.0482 |
0.5626 |
1.28 |
6400 |
0.1827 |
0.2091 |
0.0463 |
0.5342 |
1.44 |
7200 |
0.1744 |
0.2033 |
0.0468 |
0.4965 |
1.6 |
8000 |
0.1705 |
0.1963 |
0.0444 |
0.5047 |
1.76 |
8800 |
0.1604 |
0.1889 |
0.0422 |
0.4814 |
1.92 |
9600 |
0.1604 |
0.1827 |
0.0411 |
0.4471 |
2.09 |
10400 |
0.1566 |
0.1822 |
0.0406 |
0.4509 |
2.25 |
11200 |
0.1619 |
0.1853 |
0.0432 |
0.4415 |
2.41 |
12000 |
0.1513 |
0.1764 |
0.0397 |
0.4313 |
2.57 |
12800 |
0.1515 |
0.1739 |
0.0392 |
0.4163 |
2.73 |
13600 |
0.1445 |
0.1695 |
0.0377 |
0.4142 |
2.89 |
14400 |
0.1478 |
0.1699 |
0.0385 |
0.4184 |
3.05 |
15200 |
0.1430 |
0.1669 |
0.0376 |
0.3886 |
3.21 |
16000 |
0.1433 |
0.1644 |
0.0374 |
0.3795 |
3.37 |
16800 |
0.1426 |
0.1648 |
0.0373 |
0.3859 |
3.53 |
17600 |
0.1357 |
0.1604 |
0.0361 |
0.3762 |
3.69 |
18400 |
0.1344 |
0.1558 |
0.0349 |
0.384 |
3.85 |
19200 |
0.1379 |
0.1576 |
0.0359 |
0.3762 |
4.01 |
20000 |
0.1344 |
0.1539 |
0.0346 |
0.3559 |
4.17 |
20800 |
0.1339 |
0.1525 |
0.0351 |
0.3683 |
4.33 |
21600 |
0.1315 |
0.1518 |
0.0342 |
0.3572 |
4.49 |
22400 |
0.1307 |
0.1507 |
0.0342 |
0.3494 |
4.65 |
23200 |
0.1294 |
0.1491 |
0.0335 |
0.3476 |
4.81 |
24000 |
0.1287 |
0.1491 |
0.0336 |
0.3475 |
4.97 |
24800 |
0.1271 |
0.1475 |
0.0329 |
框架版本
- Transformers 4.16.2
- Pytorch 1.10.1+cu102
- Datasets 1.18.3
- Tokenizers 0.11.0