语言:
- 罗马尼亚语
许可证: apache-2.0
标签:
- 自动语音识别
- hf-asr排行榜
- 鲁棒语音赛事
数据集:
- mozilla-foundation/common_voice_8_0
- gigant/romanian_speech_synthesis_0_8_1
基础模型: facebook/wav2vec2-xls-r-300m
模型索引:
- 名称: wav2vec2-ro-300m_01
结果:
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: 鲁棒语音赛事
类型: speech-recognition-community-v2/dev_data
参数: ro
指标:
- 类型: wer
值: 46.99
名称: 开发集WER(无语言模型)
- 类型: cer
值: 16.04
名称: 开发集CER(无语言模型)
- 类型: wer
值: 38.63
名称: 开发集WER(带语言模型)
- 类型: cer
值: 14.52
名称: 开发集CER(带语言模型)
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: 通用语音
类型: mozilla-foundation/common_voice_8_0
参数: ro
指标:
- 类型: wer
值: 11.73
名称: 测试集WER(无语言模型)
- 类型: cer
值: 2.93
名称: 测试集CER(无语言模型)
- 类型: wer
值: 7.31
名称: 测试集WER(带语言模型)
- 类型: cer
值: 2.17
名称: 测试集CER(带语言模型)
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: 鲁棒语音赛事-测试数据
类型: speech-recognition-community-v2/eval_data
参数: ro
指标:
- 类型: wer
值: 43.23
名称: 测试集WER
您可以通过罗马尼亚语语音识别空间在线测试该模型
该模型在HuggingFace鲁棒语音挑战赛中位列罗马尼亚语语音识别第一名:
罗马尼亚语Wav2Vec2
该模型是基于facebook/wav2vec2-xls-r-300m在通用语音8.0-罗马尼亚语子集数据集上微调的版本,并额外使用了罗马尼亚语语音合成数据集进行训练。
在不使用5-gram语言模型优化的情况下,在评估集(通用语音8.0,罗马尼亚语子集,测试分割)上取得了以下结果:
- 损失: 0.1553
- WER: 0.1174
- CER: 0.0294
模型描述
架构基于facebook/wav2vec2-xls-r-300m,带有语音识别CTC头部,并添加了5-gram语言模型(使用pyctcdecode和kenlm)训练于罗马尼亚议会语料库数据集。需要这些库才能使语言模型增强的解码器工作。
预期用途与限制
该模型用于从16kHz采样的音频片段中进行罗马尼亚语语音识别。预测文本为小写且不包含任何标点符号。
使用方法
确保安装了正确的依赖项以使语言模型增强版本正常工作。您可以运行以下命令安装kenlm
和pyctcdecode
库:
pip install https://github.com/kpu/kenlm/archive/master.zip pyctcdecode
使用transformers
框架,您可以通过以下代码加载模型:
from transformers import AutoProcessor, AutoModelForCTC
processor = AutoProcessor.from_pretrained("gigant/romanian-wav2vec2")
model = AutoModelForCTC.from_pretrained("gigant/romanian-wav2vec2")
或者,如果您想测试模型,可以从transformers
加载自动语音识别管道:
from transformers import pipeline
asr = pipeline("automatic-speech-recognition", model="gigant/romanian-wav2vec2")
使用datasets
库的示例
首先,您需要加载数据
我们将在此示例中使用罗马尼亚语语音合成数据集。
from datasets import load_dataset
dataset = load_dataset("gigant/romanian_speech_synthesis_0_8_1")
您可以使用IPython.display
库收听样本:
from IPython.display import Audio
i = 0
sample = dataset["train"][i]
Audio(sample["audio"]["array"], rate = sample["audio"]["sampling_rate"])
该模型训练用于处理16kHz采样的音频,因此如果数据集中的音频采样率不同,我们需要重新采样。
在示例中,音频采样率为48kHz。我们可以通过检查dataset["train"][0]["audio"]["sampling_rate"]
来确认这一点。
以下代码使用torchaudio
库重新采样音频:
import torchaudio
import torch
i = 0
audio = sample["audio"]["array"]
rate = sample["audio"]["sampling_rate"]
resampler = torchaudio.transforms.Resample(rate, 16_000)
audio_16 = resampler(torch.Tensor(audio)).numpy()
收听重新采样后的样本:
Audio(audio_16, rate=16000)
现在您可以通过运行以下代码获取模型预测:
predicted_text = asr(audio_16)
ground_truth = dataset["train"][i]["sentence"]
print(f"预测文本: {predicted_text}")
print(f"真实文本: {ground_truth}")
训练和评估数据
训练数据:
评估数据:
训练过程
训练超参数
训练期间使用了以下超参数:
- 学习率: 0.003
- 训练批次大小: 16
- 评估批次大小: 8
- 随机种子: 42
- 梯度累积步数: 3
- 总训练批次大小: 48
- 优化器: Adam,betas=(0.9,0.999),epsilon=1e-08
- 学习率调度器类型: 线性
- 学习率预热步数: 500
- 训练轮数: 50.0
- 混合精度训练: Native AMP
训练结果
训练损失 |
轮次 |
步数 |
验证损失 |
WER |
CER |
2.9272 |
0.78 |
500 |
0.7603 |
0.7734 |
0.2355 |
0.6157 |
1.55 |
1000 |
0.4003 |
0.4866 |
0.1247 |
0.4452 |
2.33 |
1500 |
0.2960 |
0.3689 |
0.0910 |
0.3631 |
3.11 |
2000 |
0.2580 |
0.3205 |
0.0796 |
0.3153 |
3.88 |
2500 |
0.2465 |
0.2977 |
0.0747 |
0.2795 |
4.66 |
3000 |
0.2274 |
0.2789 |
0.0694 |
0.2615 |
5.43 |
3500 |
0.2277 |
0.2685 |
0.0675 |
0.2389 |
6.21 |
4000 |
0.2135 |
0.2518 |
0.0627 |
0.2229 |
6.99 |
4500 |
0.2054 |
0.2449 |
0.0614 |
0.2067 |
7.76 |
5000 |
0.2096 |
0.2378 |
0.0597 |
0.1977 |
8.54 |
5500 |
0.2042 |
0.2387 |
0.0600 |
0.1896 |
9.32 |
6000 |
0.2110 |
0.2383 |
0.0595 |
0.1801 |
10.09 |
6500 |
0.1909 |
0.2165 |
0.0548 |
0.174 |
10.87 |
7000 |
0.1883 |
0.2206 |
0.0559 |
0.1685 |
11.65 |
7500 |
0.1848 |
0.2097 |
0.0528 |
0.1591 |
12.42 |
8000 |
0.1851 |
0.2039 |
0.0514 |
0.1537 |
13.2 |
8500 |
0.1881 |
0.2065 |
0.0518 |
0.1504 |
13.97 |
9000 |
0.1840 |
0.1972 |
0.0499 |
0.145 |
14.75 |
9500 |
0.1845 |
0.2029 |
0.0517 |
0.1417 |
15.53 |
10000 |
0.1884 |
0.2003 |
0.0507 |
0.1364 |
16.3 |
10500 |
0.2010 |
0.2037 |
0.0517 |
0.1331 |
17.08 |
11000 |
0.1838 |
0.1923 |
0.0483 |
0.129 |
17.86 |
11500 |
0.1818 |
0.1922 |
0.0489 |
0.1198 |
18.63 |
12000 |
0.1760 |
0.1861 |
0.0465 |
0.1203 |
19.41 |
12500 |
0.1686 |
0.1839 |
0.0465 |
0.1225 |
20.19 |
13000 |
0.1828 |
0.1920 |
0.0479 |
0.1145 |
20.96 |
13500 |
0.1673 |
0.1784 |
0.0446 |
0.1053 |
21.74 |
14000 |
0.1802 |
0.1810 |
0.0456 |
0.1071 |
22.51 |
14500 |
0.1769 |
0.1775 |
0.0444 |
0.1053 |
23.29 |
15000 |
0.1920 |
0.1783 |
0.0457 |
0.1024 |
24.07 |
15500 |
0.1904 |
0.1775 |
0.0446 |
0.0987 |
24.84 |
16000 |
0.1793 |
0.1762 |
0.0446 |
0.0949 |
25.62 |
16500 |
0.1801 |
0.1766 |
0.0443 |
0.0942 |
26.4 |
17000 |
0.1731 |
0.1659 |
0.0423 |
0.0906 |
27.17 |
17500 |
0.1776 |
0.1698 |
0.0424 |
0.0861 |
27.95 |
18000 |
0.1716 |
0.1600 |
0.0406 |
0.0851 |
28.73 |
18500 |
0.1662 |
0.1630 |
0.0410 |
0.0844 |
29.5 |
19000 |
0.1671 |
0.1572 |
0.0393 |
0.0792 |
30.28 |
19500 |
0.1768 |
0.1599 |
0.0407 |
0.0798 |
31.06 |
20000 |
0.1732 |
0.1558 |
0.0394 |
0.0779 |
31.83 |
20500 |
0.1694 |
0.1544 |
0.0388 |
0.0718 |
32.61 |
21000 |
0.1709 |
0.1578 |
0.0399 |
0.0732 |
33.38 |
21500 |
0.1697 |
0.1523 |
0.0391 |
0.0708 |
34.16 |
22000 |
0.1616 |
0.1474 |
0.0375 |
0.0678 |
34.94 |
22500 |
0.1698 |
0.1474 |
0.0375 |
0.0642 |
35.71 |
23000 |
0.1681 |
0.1459 |
0.0369 |
0.0661 |
36.49 |
23500 |
0.1612 |
0.1411 |
0.0357 |
0.0629 |
37.27 |
24000 |
0.1662 |
0.1414 |
0.0355 |
0.0587 |
38.04 |
24500 |
0.1659 |
0.1408 |
0.0351 |
0.0581 |
38.82 |
25000 |
0.1612 |
0.1382 |
0.0352 |
0.0556 |
39.6 |
25500 |
0.1647 |
0.1376 |
0.0345 |
0.0543 |
40.37 |
26000 |
0.1658 |
0.1335 |
0.0337 |
0.052 |
41.15 |
26500 |
0.1716 |
0.1369 |
0.0343 |
0.0513 |
41.92 |
27000 |
0.1600 |
0.1317 |
0.0330 |
0.0491 |
42.7 |
27500 |
0.1671 |
0.1311 |
0.0328 |
0.0463 |
43.48 |
28000 |
0.1613 |
0.1289 |
0.0324 |
0.0468 |
44.25 |
28500 |
0.1599 |
0.1260 |
0.0315 |
0.0435 |
45.03 |
29000 |
0.1556 |
0.1232 |
0.0308 |
0.043 |
45.81 |
29500 |
0.1588 |
0.1240 |
0.0309 |
0.0421 |
46.58 |
30000 |
0.1567 |
0.1217 |
0.0308 |
0.04 |
47.36 |
30500 |
0.1533 |
0.1198 |
0.0302 |
0.0389 |
|
|
|
|
|