license: apache-2.0
language:
tags:
- common_voice_8_0
- generated_from_trainer
- hf-asr-leaderboard
- mozilla-foundation/common_voice_8_0
- robust-speech-event
datasets:
- mozilla-foundation/common_voice_8_0
model-index:
- name: wave2vec-xls-r-300m-es
  results:
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: mozilla-foundation/common_voice_8_0 es
      type: mozilla-foundation/common_voice_8_0
      args: es
    metrics:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Robust Speech Event - Dev Data
      type: speech-recognition-community-v2/dev_data
      args: es
    metrics:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Robust Speech Event - Test Data
      type: speech-recognition-community-v2/eval_data
      args: es
    metrics:
Wav2Vec2-XLSR-300m-es
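This model is a fine-tuned version of facebook/wav2vec2-xls-r-300m on the Spanish common_voice dataset, thanks to the GPU credits generously provided by OVHcloud for the Speech Recognition challenge.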
It achieves the following results on the evaluation set:
Without LM:
With 5-gram LM:
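For context on the two decoding modes listed above: without a language model, CTC output is usually decoded greedily, by taking the most likely token per frame, collapsing consecutive repeats, and dropping the blank token. The sketch below illustrates that idea on a toy vocabulary. The vocabulary, `blank_id=0`, and the helper name `ctc_greedy_decode` are illustrative assumptions and are not taken from this model's tokenizer.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_char, blank_id=0):
    """Collapse consecutive repeated IDs, then drop the CTC blank token."""
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    return "".join(id_to_char[i] for i in collapsed if i != blank_id)


# Hypothetical toy vocabulary; ID 0 plays the role of the CTC blank.
toy_vocab = {0: "", 1: "h", 2: "o", 3: "l", 4: "a", 5: "m"}

# Per-frame argmax IDs, as torch.argmax(logits, dim=-1) would produce.
print(ctc_greedy_decode([1, 1, 0, 2, 2, 3, 3, 0, 4], toy_vocab))  # hola

# A blank between two identical IDs keeps a genuine double letter.
print(ctc_greedy_decode([3, 0, 3, 4, 5, 4], toy_vocab))  # llama
```

An n-gram language model replaces this per-frame argmax with a beam search that also scores candidate word sequences, which is why the LM-based decoder takes the raw logits rather than the argmax IDs.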
Usage with 5-gram
The model can be used with the n-gram (n=5) language model included in the processor, as follows:
import re

import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM

processor = Wav2Vec2ProcessorWithLM.from_pretrained("polodealvarado/xls-r-300m-es")
model = AutoModelForCTC.from_pretrained("polodealvarado/xls-r-300m-es")

def remove_extra_chars(batch):
    # Lowercase the transcript and keep only the listed letters and spaces.
    # Common Voice stores the transcript in the "sentence" column.
    chars_to_ignore_regex = "[^a-záéíóúñ ]"
    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"].lower())
    return batch

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt", padding=True
    ).input_values[0]
    # as_target_processor matches the Transformers version listed under
    # "Framework versions"; later releases deprecate it.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

common_voice_test = load_dataset("mozilla-foundation/common_voice_8_0", "es", split="test", use_auth_token=True)
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
# Resample the audio to the 16 kHz rate the model expects.
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.map(remove_extra_chars)
common_voice_test = common_voice_test.map(prepare_dataset)

# Add a batch dimension: the model expects input of shape (batch, samples).
inputs = torch.tensor(common_voice_test[0]["input_values"]).unsqueeze(0)
with torch.no_grad():
    logits = model(inputs).logits

# Greedy token IDs; not used by the LM decoder, which takes the raw logits.
pred_ids = torch.argmax(logits, dim=-1)
# Decode with the 5-gram LM; batch_decode expects a NumPy array of logits.
text = processor.batch_decode(logits.numpy()).text
print(text)
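To make the text normalization in `remove_extra_chars` concrete, here is a small standalone sketch that applies the same regular expression to sample strings. The helper name `normalize_transcript` and the sample sentences are illustrative assumptions; only the regular expression comes from the snippet above.

```python
import re

# Same character whitelist as in the usage snippet: lowercase a-z,
# the accented vowels, ñ, and the space character.
CHARS_TO_IGNORE = re.compile("[^a-záéíóúñ ]")


def normalize_transcript(text: str) -> str:
    """Lowercase the text and delete every character outside the whitelist."""
    return CHARS_TO_IGNORE.sub("", text.lower())


print(normalize_transcript("¡Hola, Mundo! ¿Qué tal?"))  # hola mundo qué tal

# Characters outside the whitelist are deleted, not replaced, so digits
# leave a double space behind and "ü" disappears from the word.
print(repr(normalize_transcript("Tengo 2 gatos")))
print(normalize_transcript("ping\u00fcino"))
```

Because "ü" is not in the whitelist, words such as "pingüino" lose a letter during normalization. Whether that matters depends on how the reference transcripts were prepared for training, which this card does not state.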
In addition, you can run an evaluation with the eval.py file:
$ python eval.py --model_id polodealvarado/xls-r-300m-es --dataset mozilla-foundation/common_voice_8_0 --config es --device 0 --split test
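The internals of eval.py are not shown in this card. As background for the WER figures reported in the training results, word error rate is conventionally defined as the word-level edit distance (substitutions, deletions, and insertions) divided by the number of reference words. The sketch below is a plain-Python illustration of that definition; the function name `word_error_rate` is hypothetical and this is not necessarily how eval.py computes the metric.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference words seen
    # so far (excluding the current one) and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1
            insertion = curr[j - 1] + 1
            curr[j] = min(substitution, deletion, insertion)
        prev = curr
    return prev[-1] / len(ref)


# One deleted word out of four reference words.
print(word_error_rate("el gato come pescado", "el gato come"))  # 0.25
```

WER can exceed 1.0 when the hypothesis contains more errors than the reference has words, for example when many extra words are inserted.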
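Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0003
- train_batch_size: 16
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 32
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 500
- num_epochs: 4
- mixed_precision_training: Native AMP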
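Two of the values above follow from simple arithmetic, sketched here in plain Python. The total batch size of 32 is the per-device batch size times the gradient accumulation steps, which assumes a single device (the card does not state the device count). The linear schedule with warmup is shown in its usual form: a linear ramp from 0 to the peak learning rate over the warmup steps, then a linear decay to 0. The function names and the `total_steps=5500` value are hypothetical; the card does not report the total number of optimizer steps.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Number of samples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to base_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# 16 samples per step x 2 accumulation steps on one (assumed) device.
print(effective_batch_size(16, 2))  # 32

# Hypothetical total step count, used for illustration only.
TOTAL_STEPS = 5500
for step in (0, 250, 500, 3000, 5500):
    print(step, linear_lr(step, base_lr=3e-4, warmup_steps=500, total_steps=TOTAL_STEPS))
```

With these settings the learning rate peaks at 0.0003 exactly at step 500 and then decreases by a constant amount per step.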
Training results
| Training Loss | Epoch | Step | Validation Loss | WER |
|---|---|---|---|---|
| 3.6747 | 0.3 | 400 | 0.6535 | 0.5926 |
| 0.4439 | 0.6 | 800 | 0.3753 | 0.3193 |
| 0.3291 | 0.9 | 1200 | 0.3267 | 0.2721 |
| 0.2644 | 1.2 | 1600 | 0.2816 | 0.2311 |
| 0.24 | 1.5 | 2000 | 0.2647 | 0.2179 |
| 0.2265 | 1.79 | 2400 | 0.2406 | 0.2048 |
| 0.1994 | 2.09 | 2800 | 0.2357 | 0.1869 |
| 0.1613 | 2.39 | 3200 | 0.2242 | 0.1821 |
| 0.1546 | 2.69 | 3600 | 0.2123 | 0.1707 |
| 0.1441 | 2.99 | 4000 | 0.2067 | 0.1619 |
| 0.1138 | 3.29 | 4400 | 0.2044 | 0.1519 |
| 0.1072 | 3.59 | 4800 | 0.1917 | 0.1457 |
| 0.0992 | 3.89 | 5200 | 0.1900 | 0.1438 |
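As a quick reading aid for the table, the sketch below computes the relative WER reduction between the first and last logged checkpoints and a rough steps-per-epoch estimate. Only the first and last rows of the table are used. The helper names are hypothetical, and the steps-per-epoch figure is approximate because the logged epoch values are rounded.

```python
def relative_reduction(first: float, last: float) -> float:
    """Fraction by which a metric dropped from its first to its last value."""
    return (first - last) / first


def steps_per_epoch(step: int, epoch: float) -> float:
    """Rough estimate of optimizer steps per epoch from one logged row."""
    return step / epoch


# First and last rows of the training results table: (step, epoch, WER).
first_row = (400, 0.3, 0.5926)
last_row = (5200, 3.89, 0.1438)

wer_drop = relative_reduction(first_row[2], last_row[2])
print(f"Relative WER reduction: {wer_drop:.1%}")

estimate = steps_per_epoch(last_row[0], last_row[1])
print(f"Approximate steps per epoch: {estimate:.0f}")
```

The validation WER is still decreasing at the last logged step, although the gain per 400 steps is much smaller than at the start of training.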
Framework versions
- Transformers 4.16.0.dev0
- PyTorch 1.10.1+cu102
- Datasets 1.17.1.dev0
- Tokenizers 0.11.0