许可证: mit
标签:
- 自动语音识别
- asr
- pytorch
- wav2vec2
- 沃洛夫语
- wo
模型索引:
- 名称: wav2vec2-xls-r-300m-wolof-lm
结果:
- 任务:
名称: 语音识别
类型: automatic-speech-recognition
指标:
- 名称: 测试WER
类型: wer
值: 21.25
- 名称: 验证损失
类型: Loss
值: 0.36
wav2vec2-xls-r-300m-wolof-lm
沃洛夫语是塞内加尔及周边国家使用的一种语言,该语言的资源较为匮乏,尤其在文本和语音领域。
为此,我们希望通过此项目为其贡献力量,这也是本仓库的初衷。
此模型是基于facebook/wav2vec2-xls-r-300m微调的版本,并使用了ALFFA_PUBLIC中最大的可用语音数据集训练的语言模型。
在评估集上取得了以下结果:
- 损失: 0.367826
- 词错误率(WER): 0.212565
模型描述
训练数据时长为16.8小时,我们将其分为10,000个音频文件用于训练,3,339个用于测试。
训练与评估数据
我们每1500步评估一次模型并记录,每33340步保存一次。
训练超参数
训练过程中使用的超参数如下:
- 学习率: 1e-4
- 训练批次大小: 3
- 评估批次大小: 8
- 总训练批次大小: 64
- 总评估批次大小: 64
- 优化器: Adam,参数为betas=(0.9,0.999),epsilon=1e-08
- 学习率调度器类型: 线性
- 学习率预热步数: 1000
- 训练轮数: 10.0
训练结果
步数 |
训练损失 |
验证损失 |
WER |
1500 |
2.854200 |
0.642243 |
0.543964 |
3000 |
0.599200 |
0.468138 |
0.429549 |
4500 |
0.468300 |
0.433436 |
0.405644 |
6000 |
0.427000 |
0.384873 |
0.344150 |
7500 |
0.377000 |
0.374003 |
0.323892 |
9000 |
0.337000 |
0.363674 |
0.306189 |
10500 |
0.302400 |
0.349884 |
0.283908 |
12000 |
0.264100 |
0.344104 |
0.277120 |
13500 |
0.254000 |
0.341820 |
0.271316 |
15000 |
0.208400 |
0.326502 |
0.260695 |
16500 |
0.203500 |
0.326209 |
0.250313 |
18000 |
0.159800 |
0.323539 |
0.239851 |
19500 |
0.158200 |
0.310694 |
0.230028 |
21000 |
0.132800 |
0.338318 |
0.229283 |
22500 |
0.112800 |
0.336765 |
0.224145 |
24000 |
0.103600 |
0.350208 |
0.227073 |
25500 |
0.091400 |
0.353609 |
0.221589 |
27000 |
0.084400 |
0.367826 |
0.212565 |
使用方法
模型可直接按以下方式使用:
import librosa
import warnings
from transformers import AutoProcessor, AutoModelForCTC
from datasets import Dataset, DatasetDict
from datasets import load_metric
wer_metric = load_metric("wer")
wolof = pd.read_csv('Test.csv')
wolof = DatasetDict({'test': Dataset.from_pandas(wolof)})
chars_to_ignore_regex = '[\"\?\.\!\-\;\:\(\)\,]'
def remove_special_characters(batch):
batch["transcription"] = re.sub(chars_to_ignore_regex, '', batch["transcription"]).lower() + " "
return batch
wolof = wolof.map(remove_special_characters)
processor = AutoProcessor.from_pretrained("abdouaziiz/wav2vec2-xls-r-300m-wolof-lm")
model = AutoModelForCTC.from_pretrained("abdouaziiz/wav2vec2-xls-r-300m-wolof-lm")
warnings.filterwarnings("ignore")
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = librosa.load(batch["file"], sr = 16000)
batch["speech"] = speech_array.astype('float16')
batch["sampling_rate"] = sampling_rate
batch["target_text"] = batch["transcription"]
return batch
wolof = wolof.map(speech_file_to_array_fn, remove_columns=wolof.column_names["test"], num_proc=1)
def map_to_result(batch):
model.to("cuda")
input_values = processor(
batch["speech"],
sampling_rate=batch["sampling_rate"],
return_tensors="pt"
).input_values.to("cuda")
with torch.no_grad():
logits = model(input_values).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_str"] = processor.batch_decode(pred_ids)[0]
return batch
results = wolof["test"].map(map_to_result)
print("测试WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["transcription"])))
附注:
通过以下方式可以进一步提升结果:
- 结合Wav2vec2与语言模型
- 基于数据文本构建拼写检查器
- 使用句子编辑距离