license: apache-2.0
base_model: facebook/hubert-base-ls960
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: hubert-finetuned-animals
results: []
模型空间
点击此处体验模型:
动物声音分类空间
hubert微调动物声音模型
本模型hubert-finetuned-animals
是基于facebook/hubert-base-ls960
微调的动物声音分类专用版本。该模型经过训练可识别ESC-50数据集中动物子类别的各类叫声,在评估集上取得如下结果:
模型说明
HuBERT模型最初通过大量无标注音频数据训练,本版本针对动物声音分类任务进行微调。该模型能精准识别犬吠、鸡鸣、猪哼、牛哞、蛙鸣、猫叫、母鸡啼、昆虫声、羊咩及乌鸦啼等10类动物叫声,可应用于生物声学监测、教育工具及野生动物保护等交互场景。
应用场景与限制
应用场景
• 野生动物研究的音频分析
• 动物科普教育内容生成
• 娱乐类动物声音识别应用
局限性
• 训练数据仅涵盖ESC-50数据集部分动物类别
• 音频质量、背景噪声及训练集未覆盖的叫声变体会显著影响识别效果
训练与评估数据
使用ESC-50数据集的动物声音子集进行微调,该公开数据集专为环境声音分类设计。每个动物类别包含40个样本,为模型提供多样化的训练评估素材。
训练流程
- 预处理:音频文件转为频谱图
- 数据划分:70%训练集/20%测试集/10%验证集
- 微调阶段:进行10个epoch的训练
- 评估机制:每epoch后验证集性能监测
超参数配置
- 学习率:5e-05
- 训练批次:8
- 评估批次:8
- 随机种子:42
- 优化器:Adam(betas=0.9/0.999, epsilon=1e-08)
- 学习率调度:线性预热(比例0.1)
- 训练轮次:10
训练结果
训练损失 |
轮次 |
步数 |
验证损失 |
准确率 |
2.1934 |
1.0 |
45 |
2.1765 |
0.3 |
2.0239 |
2.0 |
90 |
1.8169 |
0.45 |
1.7745 |
3.0 |
135 |
1.4817 |
0.65 |
1.3787 |
4.0 |
180 |
1.2497 |
0.75 |
1.2168 |
5.0 |
225 |
1.0048 |
0.85 |
1.0359 |
6.0 |
270 |
0.9969 |
0.775 |
0.7983 |
7.0 |
315 |
0.7467 |
0.9 |
0.7466 |
8.0 |
360 |
0.7698 |
0.85 |
0.6284 |
9.0 |
405 |
0.6097 |
0.9 |
0.8365 |
10.0 |
450 |
0.5596 |
0.95 |
框架版本
- Transformers 4.33.2
- Pytorch 2.0.1+cu118
- Datasets 2.14.5
- Tokenizers 0.13.3
GitHub仓库
动物声音分类项目
本地调用示例
import librosa
import torch
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
model_name = "ardneebwar/wav2vec2-animal-sounds-finetuned-hubert-finetuned-animals"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = HubertForSequenceClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
def predict_audio_class(audio_file):
speech, sr = librosa.load(audio_file, sr=16000)
input_values = feature_extractor(speech, return_tensors="pt", sampling_rate=16000).input_values
input_values = input_values.to(device)
with torch.no_grad():
logits = model(input_values).logits
predicted_id = torch.argmax(logits, dim=-1)
return model.config.id2label[predicted_id.item()]
audio_path = "待识别音频.wav"
print(f"预测类别: {predict_audio_class(audio_path)}")