license: apache-2.0
inference: false
pipeline_tag: audio-to-audio
Perceiver AR 符号音频模型
该模型是一个基于 Perceiver AR 架构的符号音频模型(1.34亿参数),在 GiantMIDI-Piano 数据集上预训练了27轮(1.57亿个标记)。它采用旋转位置编码进行相对位置编码,是 perceiver-io 库的训练示例。
模型描述
Perceiver AR 是对纯解码器 Transformer(如 GPT-2)的简单扩展。两者的核心构建模块都是由自注意力层和前馈网络组成的解码器层,其中自注意力使用因果注意力掩码。
Perceiver AR 在其首个注意力层中还交叉关注输入序列的更长前缀。该层是自注意力与交叉注意力的混合层:自注意力作用于输入序列的最后 n 个位置(使用因果注意力掩码),而交叉注意力从最后 n 个位置指向前 m 个位置。输入序列总长度为 m + n。这使得 Perceiver AR 能处理比纯自注意力解码器 Transformer 更长的上下文。

图1. Perceiver AR 的注意力机制(m=8前缀标记,n=3潜在标记)。
混合注意力层输出的 n 个潜在数组对应输入序列的最后 n 个标记。这些潜在标记会经过 L-1 个解码器层的堆叠处理(总注意力层数为 L)。最终层(图1未展示)为每个潜在位置预测目标标记,其权重与输入嵌入层共享。除初始对前缀序列的交叉注意力外,Perceiver AR 在架构上与纯解码器 Transformer 完全相同。
训练过程
该模型以符号音频建模为任务,在 GiantMIDI-Piano 数据集上训练了27轮(1.57亿标记)。数据集中的 MIDI 文件采用 Perceiver AR 论文 的标记化方法(详见 Huang et al (2019) 的A.2节)。所有超参数详见训练脚本。上下文长度设为6144个标记,其中2048个潜在位置,最大前缀长度为4096。每个样本的实际前缀长度在0到4096间随机选择。训练使用 PyTorch Lightning 完成,最终检查点通过库专用转换工具转为🤗模型。
用途与局限
该模型可用于基于用户定义初始潜在标记数量的音频生成,主要用于展示如何用 perceiver-io 库 训练 Perceiver AR 模型。要提升生成音频质量,需使用比 GiantMIDI-Piano 更大规模的训练数据集。
使用示例
使用前需先安装带 audio
扩展的 perceiver-io
库:
pip install perceiver-io[audio]
可通过 PyTorch 直接调用模型生成 MIDI 文件:
import torch
from perceiver.model.audio.symbolic import PerceiverSymbolicAudioModel
from perceiver.data.audio.midi_processor import decode_midi, encode_midi
from pretty_midi import PrettyMIDI
repo_id = "krasserm/perceiver-ar-sam-giant-midi"
model = PerceiverSymbolicAudioModel.from_pretrained(repo_id)
prompt_midi = PrettyMIDI("prompt.mid")
prompt = torch.tensor(encode_midi(prompt_midi)).unsqueeze(0)
output = model.generate(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0)
output_midi = decode_midi(output[0].cpu().numpy())
或使用 symbolic-audio-generation
流水线生成 MIDI:
from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic
repo_id = "krasserm/perceiver-ar-sam-giant-midi"
prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)
output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0)
也可通过 fluidsynth 渲染 MIDI 符号生成 WAV 文件(需提前安装 fluidsynth):
output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)
with open("generated_audio.wav", "wb") as f:
f.write(output["generated_audio_wav"])
音频样本
以下精选样本使用 GiantMIDI-Piano 验证集的多种提示生成(输入提示未包含在音频输出中):
音频样本 |
Top-K |
Top-p |
温度 |
前缀长度 |
潜在标记数 |
样本1 |
- |
0.95 |
0.95 |
4096 |
1 |
样本2 |
- |
0.95 |
1.0 |
4096 |
64 |
样本3 |
- |
0.95 |
1.0 |
1024 |
1 |
样本4 |
15 |
- |
1.0 |
4096 |
16 |
样本5 |
- |
0.95 |
1.0 |
4096 |
1 |
检查点转换
krasserm/perceiver-ar-sam-giant-midi
模型通过以下代码从训练检查点转换:
from perceiver.model.audio.symbolic import convert_checkpoint
convert_checkpoint(
save_dir="krasserm/perceiver-ar-sam-giant-midi",
ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/sam/version_1/checkpoints/epoch=027-val_loss=1.944.ckpt",
push_to_hub=True,
)
引用
@inproceedings{hawthorne2022general,
title={General-purpose, long-context autoregressive modeling with perceiver ar},
author={Hawthorne, Curtis and Jaegle, Andrew and Cangea, C{\u{a}}t{\u{a}}lina and others},
booktitle={ICML},
pages={8535--8558},
year={2022},
organization={PMLR}
}