Wav2vec2 German Model
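This model is based on wav2vec-large-xlsr-53 and was fine-tuned on the German CommonVoice dataset.
It reaches a word error rate (WER) of 11.26% on the full test set. Training largely used code provided by Max Idahl, with minor changes to data preprocessing and training parameters.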
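Use the code below to transcribe your own audio files. The input must be a 16 kHz mono *.wav file. To convert an audio file with ffmpeg, run: "ffmpeg -i input.wav -ar 16000 -ac 1 output.wav". Transcription is memory-intensive (roughly 10 GB per 10 seconds of audio). If the script ends with "Killed", the Python interpreter ran out of memory; in that case, try a shorter audio clip.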
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

# Load the fine-tuned tokenizer and model
tokenizer = Wav2Vec2Tokenizer.from_pretrained("Noricum/wav2vec2-large-xlsr-53-german")
model = Wav2Vec2ForCTC.from_pretrained("Noricum/wav2vec2-large-xlsr-53-german")

# Read a 16 kHz mono WAV file; the returned sample rate is not checked here
audio_input, _ = sf.read("/path/to/your/audio.wav")

# Convert the waveform to model inputs and compute per-frame logits
input_values = tokenizer(audio_input, return_tensors="pt").input_values
logits = model(input_values).logits

# Greedy decoding: pick the most likely token per frame, then decode to text
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]
print(str(transcription))
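Since memory use grows with clip length, one way to act on the advice about shortening audio is to cut a long recording into shorter pieces and transcribe each piece separately. The following is a minimal sketch under stated assumptions, not part of the original model card: `split_into_chunks` is a hypothetical helper, the 10-second default is only an illustration, and cutting at fixed offsets can split a word in half, which may hurt accuracy near chunk boundaries.

```python
def split_into_chunks(samples, sample_rate=16_000, chunk_seconds=10.0):
    """Split a 1-D sequence of audio samples into consecutive chunks.

    Hypothetical helper (not from the original model card). Works on any
    sliceable sequence, e.g. the NumPy array returned by soundfile.read.
    """
    if sample_rate <= 0 or chunk_seconds <= 0:
        raise ValueError("sample_rate and chunk_seconds must be positive")
    chunk_len = int(sample_rate * chunk_seconds)  # samples per chunk
    if chunk_len == 0:
        raise ValueError("chunk_seconds is too short for this sample rate")
    # The last chunk may be shorter than chunk_len.
    return [samples[i:i + chunk_len] for i in range(0, len(samples), chunk_len)]


if __name__ == "__main__":
    # 25 seconds of silence at 16 kHz, as a stand-in for real audio.
    fake_audio = [0.0] * (25 * 16_000)
    chunks = split_into_chunks(fake_audio)
    print([len(c) / 16_000 for c in chunks])  # → [10.0, 10.0, 5.0]
```

Each chunk can then be passed through the tokenizer and model in place of `audio_input`, and the resulting strings joined with spaces.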
To evaluate the model on the full CommonVoice test set, run the following script:
import re
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load the German CommonVoice test split
test_dataset = load_dataset("common_voice", "de", split="test")
wer = load_metric("wer")

processor = Wav2Vec2Processor.from_pretrained("Noricum/wav2vec2-large-xlsr-53-german")
model = Wav2Vec2ForCTC.from_pretrained("Noricum/wav2vec2-large-xlsr-53-german")
model.to("cuda")

# Punctuation stripped from the reference sentences before scoring
chars_to_ignore_regex = r'[,?.!\-;:"“]'
# Resample from 48 kHz (the rate this script assumes for the clips) to 16 kHz
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
    # Normalize the reference text: strip punctuation and lowercase
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
    # Load the clip and resample it to 16 kHz
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)
def evaluate(batch):
    # Pad the batch to a common length and run greedy CTC decoding on the GPU
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

result = test_dataset.map(evaluate, batched=True, batch_size=4)
import jiwer

def chunked_wer(targets, predictions, chunk_size=None):
    # Without chunking, defer to jiwer directly
    if chunk_size is None:
        return jiwer.wer(targets, predictions)
    start = 0
    end = chunk_size
    H, S, D, I = 0, 0, 0, 0
    # Accumulate hits, substitutions, deletions and insertions chunk by chunk
    while start < len(targets):
        chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
        H = H + chunk_metrics["hits"]
        S = S + chunk_metrics["substitutions"]
        D = D + chunk_metrics["deletions"]
        I = I + chunk_metrics["insertions"]
        start += chunk_size
        end += chunk_size
    # WER = (S + D + I) / (H + S + D)
    return float(S + D + I) / float(H + S + D)

# Note: predictions are passed as `targets` here, as in the original script,
# so jiwer normalizes by the number of predicted words.
print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["pred_strings"], result["sentence"], chunk_size=1000)))
Output: Total (chunk_size=1000), WER: 11.256522
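To make this figure easier to interpret: the script strips punctuation from the reference sentences, lowercases them, and reports WER as (S + D + I) / (H + S + D), that is, word-level edits divided by the word count of the side treated as the reference (the first argument passed to `jiwer`), summed over all chunks. The sketch below reproduces that calculation in pure Python on two toy sentences. It is an illustration only: `normalize`, `edit_distance` and `corpus_wer` are hypothetical names, the Levenshtein routine is a stand-in for `jiwer`, and the example sentences are made up.

```python
import re

# Punctuation set assumed to match the one the evaluation script removes.
CHARS_TO_IGNORE = r'[,?.!\-;:"“]'


def normalize(sentence):
    """Strip punctuation and lowercase, mirroring the script's preprocessing."""
    return re.sub(CHARS_TO_IGNORE, "", sentence).lower()


def edit_distance(ref_words, hyp_words):
    """Word-level Levenshtein distance, i.e. the minimal S + D + I."""
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]


def corpus_wer(references, hypotheses):
    """Total edits divided by total reference words (H + S + D)."""
    edits = 0
    ref_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words, hyp_words = ref.split(), hyp.split()
        edits += edit_distance(ref_words, hyp_words)
        ref_len += len(ref_words)
    if ref_len == 0:
        raise ValueError("references contain no words")
    return edits / ref_len


if __name__ == "__main__":
    refs = [normalize("Guten Morgen, wie geht es dir?"), normalize("Das ist ein Test.")]
    hyps = ["guten morgen wie geht es", "das ist ein fest"]
    # One deletion plus one substitution over 10 reference words.
    print(f"WER: {100 * corpus_wer(refs, hyps):.2f}%")  # → WER: 20.00%
```

Summing edits and reference lengths before dividing, rather than averaging per-sentence rates, is what keeps the chunked result equal to the corpus-level rate regardless of chunk size.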