🚀 wav2vec2-xls-r-300m-cs-cv8
本模型是 facebook/wav2vec2-xls-r-300m 在 common_voice 8.0 数据集上的微调版本。它在自动语音识别任务中表现出色,能够将音频准确地转换为文本,为语音相关应用提供了强大的支持。
🚀 快速开始
本模型可直接使用(无需语言模型),以下是使用示例:
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
test_dataset = load_dataset("mozilla-foundation/common_voice_8_0", "sk", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
model = Wav2Vec2ForCTC.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset[:2]["sentence"])
✨ 主要特性
- 微调优化:基于
facebook/wav2vec2-xls-r-300m
在 common_voice 8.0 数据集上进行微调,更适配特定语音识别任务。
- 多指标评估:在评估集上提供了 WER(词错误率)和 CER(字符错误率)等指标,方便衡量模型性能。
💻 使用示例
基础用法
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
test_dataset = load_dataset("mozilla-foundation/common_voice_8_0", "sk", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
model = Wav2Vec2ForCTC.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset[:2]["sentence"])
高级用法
test_dataset = load_dataset("mozilla-foundation/common_voice_8_0", "sk", split="test[:5%]")
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
processor = Wav2Vec2Processor.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
model = Wav2Vec2ForCTC.from_pretrained("comodoro/wav2vec2-xls-r-300m-sk-cv8")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset[:2]["sentence"])
📚 详细文档
评估
可使用附带的 eval.py
脚本对模型进行评估:
python eval.py --model_id comodoro/wav2vec2-xls-r-300m-sk-cv8 --dataset mozilla-foundation/common_voice_8_0 --split test --config sk
训练和评估数据
训练使用了 Common Voice 8.0 的 train
和 validation
数据集。
训练超参数
训练过程中使用了以下超参数:
属性 |
详情 |
学习率 |
7e-4 |
训练批次大小 |
32 |
评估批次大小 |
8 |
随机种子 |
42 |
梯度累积步数 |
20 |
总训练批次大小 |
640 |
优化器 |
Adam(betas=(0.9,0.999),epsilon=1e-08) |
学习率调度器类型 |
线性 |
学习率调度器热身步数 |
500 |
训练轮数 |
50 |
混合精度训练 |
Native AMP |
框架版本
- Transformers 4.16.0.dev0
- Pytorch 1.10.1+cu102
- Datasets 1.17.1.dev0
- Tokenizers 0.11.0
📄 许可证
本模型使用 Apache-2.0 许可证。