语言: 英语
许可证: Apache-2.0
标签:
- 音素识别
- 训练生成
数据集:
- w11wo/ljspeech_phonemes
基础模型: Wav2Vec2-Base
推理参数:
应用函数: 无
模型索引:
- 名称: Wav2Vec2 LJSpeech Gruut
结果:
- 任务:
类型: 自动语音识别
名称: 自动语音识别
数据集:
名称: LJSpeech
类型: ljspeech_phonemes
指标:
- 类型: PER
值: 0.0099
名称: 测试PER(无重音)
- 类型: CER
值: 0.0058
名称: 测试CER(无重音)
Wav2Vec2 LJSpeech Gruut
Wav2Vec2 LJSpeech Gruut 是一个基于 wav2vec 2.0 架构的自动语音识别模型。该模型是在 LJSpech Phonemes 数据集上对 Wav2Vec2-Base 进行微调的版本。
与训练预测单词序列不同,该模型被训练用于预测音素序列,例如 ["h", "ɛ", "l", "ˈoʊ", "w", "ˈɚ", "l", "d"]
。因此,模型的词汇表 包含了 gruut 中发现的国际音标(IPA)音素。
该模型使用 HuggingFace 的 PyTorch 框架进行训练。所有训练均在配备了 Tesla A100 GPU 的 Google Cloud Engine VM 上完成。训练使用的所有脚本可在 文件和版本 标签页中找到,同时 训练指标 通过 Tensorboard 记录。
模型
模型 |
参数量 |
架构 |
训练/验证数据(文本) |
wav2vec2-ljspeech-gruut |
94M |
wav2vec 2.0 |
LJSpech Phonemes 数据集 |
评估结果
模型在评估中取得以下结果:
数据集 |
PER(无重音) |
CER(无重音) |
LJSpech Phonemes 测试数据 |
0.99% |
0.58% |
使用方法
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2Processor
import librosa
import torch
from itertools import groupby
from datasets import load_dataset
def decode_phonemes(
ids: torch.Tensor, processor: Wav2Vec2Processor, ignore_stress: bool = False
) -> str:
"""类CTC解码。先去除连续重复项,再去除特殊标记。"""
ids = [id_ for id_, _ in groupby(ids)]
special_token_ids = processor.tokenizer.all_special_ids + [
processor.tokenizer.word_delimiter_token_id
]
phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]
prediction = " ".join(phonemes)
if ignore_stress == True:
prediction = prediction.replace("ˈ", "").replace("ˌ", "")
return prediction
checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
audio_array = ds[0]["audio"]["array"]
inputs = processor(audio_array, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs["input_values"]).logits
predicted_ids = torch.argmax(logits, dim=-1)
prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
训练过程
训练超参数
训练中使用的超参数如下:
学习率
: 0.0001
训练批次大小
: 16
评估批次大小
: 8
随机种子
: 42
梯度累积步数
: 2
总训练批次大小
: 32
优化器
: Adam,参数 betas=(0.9,0.999)
和 epsilon=1e-08
学习率调度器类型
: 线性
学习率预热步数
: 1000
训练轮数
: 30.0
混合精度训练
: Native AMP
训练结果
训练损失 |
轮次 |
步数 |
验证损失 |
WER |
CER |
无记录 |
1.0 |
348 |
2.2818 |
1.0 |
1.0 |
2.6692 |
2.0 |
696 |
0.2045 |
0.0527 |
0.0299 |
0.2225 |
3.0 |
1044 |
0.1162 |
0.0319 |
0.0189 |
0.2225 |
4.0 |
1392 |
0.0927 |
0.0235 |
0.0147 |
0.0868 |
5.0 |
1740 |
0.0797 |
0.0218 |
0.0143 |
0.0598 |
6.0 |
2088 |
0.0715 |
0.0197 |
0.0128 |
0.0598 |
7.0 |
2436 |
0.0652 |
0.0160 |
0.0103 |
0.0447 |
8.0 |
2784 |
0.0571 |
0.0152 |
0.0095 |
0.0368 |
9.0 |
3132 |
0.0608 |
0.0163 |
0.0112 |
0.0368 |
10.0 |
3480 |
0.0586 |
0.0137 |
0.0083 |
0.0303 |
11.0 |
3828 |
0.0641 |
0.0141 |
0.0085 |
0.0273 |
12.0 |
4176 |
0.0656 |
0.0131 |
0.0079 |
0.0232 |
13.0 |
4524 |
0.0690 |
0.0133 |
0.0082 |
0.0232 |
14.0 |
4872 |
0.0598 |
0.0128 |
0.0079 |
0.0189 |
15.0 |
5220 |
0.0671 |
0.0121 |
0.0074 |
0.017 |
16.0 |
5568 |
0.0654 |
0.0114 |
0.0069 |
0.017 |
17.0 |
5916 |
0.0751 |
0.0118 |
0.0073 |
0.0146 |
18.0 |
6264 |
0.0653 |
0.0112 |
0.0068 |
0.0127 |
19.0 |
6612 |
0.0682 |
0.0112 |
0.0069 |
0.0127 |
20.0 |
6960 |
0.0678 |
0.0114 |
0.0068 |
0.0114 |
21.0 |
7308 |
0.0656 |
0.0111 |
0.0066 |
0.0101 |
22.0 |
7656 |
0.0669 |
0.0109 |
0.0066 |
0.0092 |
23.0 |
8004 |
0.0677 |
0.0108 |
0.0065 |
0.0092 |
24.0 |
8352 |
0.0653 |
0.0104 |
0.0063 |
0.0088 |
25.0 |
8700 |
0.0673 |
0.0102 |
0.0063 |
0.0074 |
26.0 |
9048 |
0.0669 |
0.0105 |
0.0064 |
0.0074 |
27.0 |
9396 |
0.0707 |
0.0101 |
0.0061 |
0.0066 |
28.0 |
9744 |
0.0673 |
0.0100 |
0.0060 |
0.0058 |
29.0 |
10092 |
0.0689 |
0.0100 |
0.0059 |
0.0058 |
30.0 |
10440 |
0.0683 |
0.0099 |
0.0058 |
免责声明
请注意,预训练数据集可能存在的偏见可能会影响该模型的结果。
作者
Wav2Vec2 LJSpeech Gruut 由 Wilson Wongso 训练和评估。所有计算和开发均在 Google Cloud 上完成。
框架版本
- Transformers 4.26.0.dev0
- Pytorch 1.10.0
- Datasets 2.7.1
- Tokenizers 0.13.2
- Gruut 2.3.4