语言: 葡萄牙语
数据集:
- 通用语音(common_voice)
- 多语言语音库(mls)
- 天主教大学语音库(cetuc)
- 圣保罗大学商业管理语音库(lapsbm)
- 开源语音库(voxforge)
- TEDx演讲数据集(tedx)
- 语音识别数据集(sid)
评估指标:
标签:
- 音频
- 语音识别
- wav2vec2模型
- 葡萄牙语(pt)
- 葡萄牙语语音语料库
- 自动语音识别(ASR)
- PyTorch框架
许可协议: Apache-2.0
TEDx100-XLSR:基于TEDx数据集的Wav2vec 2.0模型
这是使用TEDx葡萄牙语多语言数据集微调的巴西葡萄牙语Wav2vec模型演示。
本笔记本测试了该模型在其他巴西葡萄牙语数据集上的表现。
数据集 |
训练时长 |
验证时长 |
测试时长 |
CETUC |
|
-- |
5.4小时 |
通用语音(Common Voice) |
|
-- |
9.5小时 |
LaPS BM |
|
-- |
0.1小时 |
MLS |
|
-- |
3.7小时 |
多语言TEDx(葡萄牙语) |
148.8小时 |
-- |
1.8小时 |
SID |
|
-- |
1.0小时 |
VoxForge |
|
-- |
0.1小时 |
总计 |
148.8小时 |
-- |
21.6小时 |
性能摘要
|
CETUC |
通用语音 |
LaPS |
MLS |
SID |
TEDx |
VoxForge |
平均 |
tedx_100 (如下演示) |
0.138 |
0.369 |
0.169 |
0.165 |
0.794 |
0.222 |
0.395 |
0.321 |
tedx_100 + 4-gram语言模型 (如下演示) |
0.123 |
0.414 |
0.171 |
0.152 |
0.982 |
0.215 |
0.395 |
0.350 |
演示
MODEL_NAME = "lgris/tedx100-xlsr"
导入依赖
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
import jiwer
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
)
from pyctcdecode import build_ctcdecoder
import torch
import re
import sys
辅助函数
chars_to_ignore_regex = '[\,\?\.\!\;\:\"]'
def map_to_array(batch):
speech, _ = torchaudio.load(batch["path"])
batch["speech"] = speech.squeeze(0).numpy()
batch["sampling_rate"] = 16_000
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
batch["target"] = batch["sentence"]
return batch
def calc_metrics(truths, hypos):
wers = []
mers = []
wils = []
for t, h in zip(truths, hypos):
try:
wers.append(jiwer.wer(t, h))
mers.append(jiwer.mer(t, h))
wils.append(jiwer.wil(t, h))
except:
pass
wer = sum(wers)/len(wers)
mer = sum(mers)/len(mers)
wil = sum(wils)/len(wils)
return wer, mer, wil
def load_data(dataset):
data_files = {'test': f'{dataset}/test.csv'}
dataset = load_dataset('csv', data_files=data_files)["test"]
return dataset.map(map_to_array)
语音识别模型类
class STT:
def __init__(self,
model_name,
device='cuda' if torch.cuda.is_available() else 'cpu',
lm=None):
self.model_name = model_name
self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
self.vocab_dict = self.processor.tokenizer.get_vocab()
self.sorted_dict = {
k.lower(): v for k, v in sorted(self.vocab_dict.items(),
key=lambda item: item[1])
}
self.device = device
self.lm = lm
if self.lm:
self.lm_decoder = build_ctcdecoder(
list(self.sorted_dict.keys()),
self.lm
)
def batch_predict(self, batch):
features = self.processor(batch["speech"],
sampling_rate=batch["sampling_rate"][0],
padding=True,
return_tensors="pt")
input_values = features.input_values.to(self.device)
attention_mask = features.attention_mask.to(self.device)
with torch.no_grad():
logits = self.model(input_values, attention_mask=attention_mask).logits
if self.lm:
logits = logits.cpu().numpy()
batch["predicted"] = []
for sample_logits in logits:
batch["predicted"].append(self.lm_decoder.decode(sample_logits))
else:
pred_ids = torch.argmax(logits, dim=-1)
batch["predicted"] = self.processor.batch_decode(pred_ids)
return batch
下载测试数据集
%%capture
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
各数据集测试结果
基础模型测试
stt = STT(MODEL_NAME)
CETUC数据集
词错误率: 0.138
通用语音数据集
词错误率: 0.370
LaPS数据集
词错误率: 0.169
MLS数据集
词错误率: 0.166
SID数据集
词错误率: 0.794
TEDx数据集
词错误率: 0.222
VoxForge数据集
词错误率: 0.395
结合4-gram语言模型的测试
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
CETUC数据集
词错误率: 0.123
通用语音数据集
词错误率: 0.415
LaPS数据集
词错误率: 0.171
MLS数据集
词错误率: 0.152
SID数据集
词错误率: 0.983
TEDx数据集
词错误率: 0.216
VoxForge数据集
词错误率: 0.395