Model for Dimensional Speech Emotion Recognition based on Wav2vec 2.0
Please note that this model is for research purposes only. A commercially licensed version, trained on a larger dataset, is available from audEERING.

The model takes a raw audio signal as input and outputs predictions for arousal, dominance, and valence in a range of approximately 0...1. In addition, it provides the pooled states of the last transformer layer. It was created by fine-tuning Wav2Vec2-Large-Robust on MSP-Podcast (v1.7); before fine-tuning, the transformer was pruned from 24 to 12 layers. An ONNX export of the model is available at doi:10.5281/zenodo.6221127. Further details are given in the associated paper and tutorial.
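The ONNX export can be run without the PyTorch stack. Below is a minimal sketch using onnxruntime; the local file name model.onnx is an assumption (download the export from the Zenodo record first), and the exact input layout of the graph should be verified against that record.

import numpy as np
import onnxruntime as ort

# Hypothetical local path; see the Zenodo record for the actual export.
session = ort.InferenceSession('model.onnx')

# Discover the input name instead of hard-coding it.
input_name = session.get_inputs()[0].name

# One second of silence at 16 kHz, as in the PyTorch example below.
signal = np.zeros((1, 16000), dtype=np.float32)
outputs = session.run(None, {input_name: signal})
print([o.shape for o in outputs])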
Usage
import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
class RegressionHead(nn.Module):
    r"""Regression head for the dimensional predictions."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
class EmotionModel(Wav2Vec2PreTrainedModel):
    r"""Speech emotion classifier."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)  # pool over time
        logits = self.classifier(hidden_states)
        return hidden_states, logits
device = 'cpu'
model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name).to(device)
# dummy signal: one second of silence
sampling_rate = 16000
signal = np.zeros((1, sampling_rate), dtype=np.float32)
def process_func(
    x: np.ndarray,
    sampling_rate: int,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict emotions or extract embeddings from raw audio signal."""
    # normalize the signal with the processor
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)[0 if embeddings else 1]

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y
print(process_func(signal, sampling_rate))
# arousal, dominance, valence: array of shape (1, 3)
print(process_func(signal, sampling_rate, embeddings=True))
# pooled states of the last transformer layer: array of shape (1, 1024)
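The dummy signal above is just silence. To run the model on a real recording, the audio first has to be converted to a mono 16 kHz signal. A minimal sketch, assuming torchaudio is installed and using a hypothetical file speech.wav:

import torchaudio

waveform, sr = torchaudio.load('speech.wav')   # shape: (channels, samples)
waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
if sr != sampling_rate:
    waveform = torchaudio.functional.resample(waveform, sr, sampling_rate)

# process_func expects a numpy array
print(process_func(waveform.numpy(), sampling_rate))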