库名称:transformers
许可证:apache-2.0
标签:[]
流水线标签:音频文本转文本
R1-AQA —— 强化学习超越监督微调:音频问答的案例研究
简介
R1-AQA 是基于 Qwen2-Audio-7B-Instruct
的音频问答(AQA)模型,通过群体相对策略优化(GRPO)算法进行强化学习优化。
该实现仅使用 38k 训练后样本即在 MMAU 基准测试中取得了最先进的性能。
更多细节请参阅我们的 Github 和 技术报告。
我们的主要发现如下:
- GRPO 算法可以直接且有效地应用于音频模态,甚至适用于仅 8.2B 参数的
Qwen2-Audio-7B-Instruct
。
- 仅使用 38k 训练后样本,强化学习即超越监督微调,表明基于 RL 的方法无需大数据集即可有效。
- 显式推理过程对 AQA 任务未显示出显著优势,如何高效利用深度思考或逐步推理仍是待研究的开放问题。
- 大型音频语言模型(LALM)的听觉-语言推理能力仍远落后于人类,表明基于 RL 的方法值得进一步探索。
补充说明:
- AVQA 训练集原包含约 40k 样本,但因部分数据源失效,实际使用约 38k 样本。其他使用 YouTube 源的数据集(如 AudioSet)也存在类似问题。我们认为缺失的 2k 样本对训练结果无显著影响。
- 关于 8.2B 参数的说明基于《Qwen2-Audio 技术报告》。
表格:MMAU 基准测试准确率(%)
模型 |
方法 |
Test-mini |
Test |
Test-mini |
Test |
Test-mini |
Test |
Test-mini |
Test |
- |
人类* |
86.31 |
- |
78.22 |
- |
82.17 |
- |
82.23 |
- |
Gemini Pro 2.0 Flash |
直接推理* |
56.46 |
61.73 |
58.68 |
56.53 |
51.65 |
61.53 |
55.60 |
59.93 |
Audio Flamingo 2 |
直接推理* |
61.56 |
65.10 |
73.95 |
72.90 |
30.93 |
40.26 |
55.48 |
59.42 |
GPT4o + 强能力 |
直接推理* |
57.35 |
55.83 |
49.70 |
51.73 |
64.86 |
68.66 |
57.30 |
58.74 |
Llama-3-8B-Instruct + 强能力 |
直接推理* |
50.75 |
49.10 |
48.93 |
48.93 |
55.25 |
62.70 |
52.10 |
53.57 |
Qwen2-Audio-7B-Instruct |
直接推理* |
54.95 |
45.90 |
50.98 |
53.26 |
42.04 |
45.90 |
49.20 |
52.50 |
SALAMONN |
直接推理* |
41.00 |
40.30 |
34.80 |
33.76 |
25.50 |
24.24 |
33.70 |
32.77 |
Qwen2-Audio-7B-Instruct |
CoTA [1] |
60.06 |
- |
64.30 |
- |
60.70 |
- |
61.71 |
- |
Qwen2-Audio-7B-Instruct |
Zero-Shot-CoT [2] |
61.86 |
- |
56.29 |
- |
55.26 |
- |
57.80 |
- |
Qwen2-Audio-7B-Instruct |
GRPO (Ours) 1️⃣ |
69.37 |
- |
66.77 |
- |
57.36 |
- |
64.50 |
- |
Qwen2-Audio-7B-Instruct |
GRPO (Ours) 2️⃣ |
68.77 |
69.76 |
64.37 |
61.40 |
63.66 |
62.70 |
65.60 |
64.36 |
注释
* 数据源自 MMAU 排行榜。
[1] Xie, Zhifei 等. "Audio-Reasoner: 提升大型音频语言模型的推理能力." arXiv 预印本 arXiv:2503.02318 (2025)。
[2] Ma, Ziyang 等. "Audio-CoT: 探索大型音频语言模型的思维链推理." arXiv 预印本 arXiv:2501.07246 (2025)。
1️⃣ 原始模型,与 Hugging Face 上的版本及技术报告中描述一致。
2️⃣ 提交至 MMAU 排行榜 的模型,经多次训练以获得均衡结果。
推理
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]
question = "根据给定音频,识别说话声音的来源。"
options = ["男性", "女性", "儿童", "机器人"]
prompt = f"{question} 请从以下选项中选择答案:{str(options)}。最终答案请用 <answer> </answer> 输出。"
message = [
{"role": "user", "content": [
{"type": "audio", "audio_url": wav_path},
{"type": "text", "text": prompt}
]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
引用
@article{li2025reinforcement,
title={强化学习超越监督微调:音频问答的案例研究},
author={李刚 and 刘继忠 and Dinkel, Heinrich and 牛亚东 and 张军波 and 栾健},
journal={arXiv 预印本 arXiv:2503.11197},
year={2025},
url={https://github.com/xiaomi-research/r1-aqa; https://huggingface.co/mispeech/r1-aqa}
}