---
language: pt
datasets:
- common_voice
- mls
- cetuc
- lapsbm
- voxforge
- tedx
- sid
metrics:
- wer
tags:
- audio
- speech
- wav2vec2
- pt
- portuguese-speech-corpus
- automatic-speech-recognition
- PyTorch
license: apache-2.0
---
# bp500-base100k_voxpopuli: Wav2vec 2.0 model for Brazilian Portuguese (BP) datasets

This is a demonstration of a Wav2vec model fine-tuned for Brazilian Portuguese using the datasets below. These datasets were combined to build a larger Brazilian Portuguese training set. All data was used for training except for the Common Voice dev/test sets, which were used for validation/testing respectively. We also made test sets for all the datasets.
| Dataset | Train | Valid | Test |
|---------|-------|-------|------|
| CETUC | 94.0h | -- | 5.4h |
| Common Voice | 37.8h | 8.9h | 9.5h |
| LaPS BM | 0.8h | -- | 0.1h |
| MLS | 161.0h | -- | 3.7h |
| TEDx (Portuguese) | 148.9h | -- | 1.8h |
| SID | 7.2h | -- | 1.0h |
| VoxForge | 3.9h | -- | 0.1h |
| Total | 453.6h | 8.9h | 21.6h |
The original model was fine-tuned using fairseq; this demonstration uses a converted version. The original model can be downloaded here.

## Summary
| Model | CETUC | CV | LaPS | MLS | SID | TEDx | VF | AVG |
|-------|-------|----|------|-----|-----|------|----|----|
| bp_500-base100k_voxpopuli (demonstration below) | 0.142 | 0.201 | 0.052 | 0.224 | 0.102 | 0.317 | 0.048 | 0.155 |
| bp_500-base100k_voxpopuli + 4-gram LM (demonstration below) | 0.099 | 0.149 | 0.047 | 0.192 | 0.115 | 0.371 | 0.127 | 0.157 |
## Transcription examples

| Text | Transcription |
|------|---------------|
| qual o instagram dele | qualo está gramedele |
| o capitão foi expulso do exército porque era doido | o capitãl foi exposo do exército porque era doido |
| também por que não | também porque não |
| não existe tempo como o presente | não existe tempo como o presente |
| eu pulei para salvar rachel | eu pulei para salvar haquel |
| augusto cezar passos marinho | augusto cesa passoesmarinho |
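The WER figures reported throughout are word-level edit distances between reference and transcription, normalized by reference length. The demo itself scores with jiwer; below is a minimal pure-Python sketch of the metric for illustration:

```python
# Word error rate via word-level Levenshtein distance -- a minimal
# sketch of the metric reported above (the demo uses jiwer instead).
def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    # d[i][j] = edit distance between ref[:i] and hyp[:j].
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + sub)  # substitution / match
    return d[len(ref)][len(hyp)] / len(ref)

# Example pair from the table above: one substitution plus one deletion
# over a four-word reference gives WER 0.5.
print(wer("também por que não", "também porque não"))  # 0.5
```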
## Demonstration

```python
MODEL_NAME = "lgris/bp500-base100k_voxpopuli"
```
### Imports and dependencies

```python
%%capture
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
```
### Helpers

```python
import re

import jiwer
import torch
import torchaudio
from datasets import load_dataset, load_metric
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Punctuation to strip from reference transcriptions before scoring.
chars_to_ignore_regex = r'[\,\?\.\!\;\:\"]'

def map_to_array(batch):
    # Load the audio file and normalize the reference sentence.
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch

def calc_metrics(truths, hypos):
    # Average WER, MER, and WIL over all reference/hypothesis pairs.
    wers, mers, wils = [], [], []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except ValueError:  # skip pairs jiwer cannot score (e.g. empty references)
            pass
    return sum(wers)/len(wers), sum(mers)/len(mers), sum(wils)/len(wils)

class STT:

    def __init__(self, model_name, device='cuda' if torch.cuda.is_available() else 'cpu', lm=None):
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.device = device
        self.lm = lm
        if self.lm:
            # Build a CTC beam-search decoder over the tokenizer vocabulary
            # (ordered by token id), backed by the KenLM language model.
            vocab_dict = self.processor.tokenizer.get_vocab()
            sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
            self.lm_decoder = build_ctcdecoder(list(sorted_dict.keys()), self.lm)

    def batch_predict(self, batch):
        features = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"][0],
                                  padding=True, return_tensors="pt")
        input_values = features.input_values.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values).logits
        if self.lm:
            # Beam-search decoding rescored by the language model.
            batch["predicted"] = [self.lm_decoder.decode(sample) for sample in logits.cpu().numpy()]
        else:
            # Greedy (argmax) CTC decoding.
            batch["predicted"] = self.processor.batch_decode(torch.argmax(logits, dim=-1))
        return batch
```
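The greedy path (`processor.batch_decode(torch.argmax(logits, dim=-1))`) amounts to taking the most likely token per frame, collapsing consecutive repeats, and dropping the CTC blank token. A toy sketch of that collapse step, with an illustrative four-token vocabulary that is not the model's real one:

```python
# Greedy CTC decoding sketch: collapse repeated frame predictions,
# then remove the blank token.  The vocabulary below is made up for
# illustration; the real tokenizer vocabulary is much larger.
import itertools

def ctc_greedy_decode(frame_ids, blank_id=0, id_to_char=None):
    # Collapse consecutive duplicates (e.g. [1, 1, 0, 2] -> [1, 0, 2]) ...
    collapsed = [k for k, _ in itertools.groupby(frame_ids)]
    # ... then drop blanks, which only mark boundaries between repeats.
    ids = [i for i in collapsed if i != blank_id]
    if id_to_char is None:
        return ids
    return "".join(id_to_char[i] for i in ids)

vocab = {0: "<pad>", 1: "o", 2: "i", 3: " "}
# Frames: o, o, blank, i, i, blank, blank, o  ->  "oio"
print(ctc_greedy_decode([1, 1, 0, 2, 2, 0, 0, 1], id_to_char=vocab))  # oio
```

Note how the blank between the two runs of token 1 is what allows a genuinely doubled letter to survive the repeat-collapse.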
## Test results

### Tests with the base model

- CETUC WER: 0.142
- Common Voice WER: 0.201
- LaPS WER: 0.052
- MLS WER: 0.224
- SID WER: 0.102
- TEDx WER: 0.317
- VoxForge WER: 0.048

### Tests with a language model

- CETUC WER: 0.100
- Common Voice WER: 0.149
- LaPS WER: 0.047
- MLS WER: 0.192
- SID WER: 0.115
- TEDx WER: 0.371
- VoxForge WER: 0.127

Note: the language model was trained on Portuguese Wikipedia text.
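A 4-gram language model of this kind is built from counts of four-word windows over the corpus (KenLM's `lmplz -o 4` is the usual tool; the toy corpus below is an assumption for illustration, not the Wikipedia data actually used):

```python
# Toy illustration of the statistic a 4-gram LM is estimated from:
# counts of 4-word windows over a text corpus.
from collections import Counter

def ngram_counts(text: str, n: int = 4) -> Counter:
    words = text.split()
    # One count per sliding window of n consecutive words.
    return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

corpus = "não existe tempo como o presente não existe tempo como o passado"
counts = ngram_counts(corpus, n=4)
print(counts[("não", "existe", "tempo", "como")])  # 2
```

A real LM turns such counts into smoothed conditional probabilities, which the beam-search decoder uses to rescore candidate transcriptions.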