library_name: transformers
tags: []
language:
- en
- fr
- es
- de
- el
- bg
- ru
- tr
- ar
- vi
- th
- zh
- hi
- sw
- ur
datasets:
- allenai/c4

MrT5(MergeT5)是ByT5(Xue等学者,2022)的高效改进版本,通过在编码器中集成令牌删除机制来动态缩短输入序列长度。经过固定数量的编码层处理后,一个可学习的删除门控会决定哪些令牌需要保留或删除。通过将已删除令牌的关键信息有效"合并"到更紧凑的序列中,MrT5为现有字节级模型的实际局限性提供了解决方案。
引用说明
若使用本模型,请引用MrT5论文:
@inproceedings{
kallini2025mrt,
title={MrT5:面向高效字节级语言模型的动态令牌合并技术},
author={Julie Kallini and Shikhar Murty and Christopher D Manning and Christopher Potts and R{\'o}bert Csord{\'a}s},
booktitle={第十三届国际学习表征会议},
year={2025},
url={https://openreview.net/forum?id=VYWBMq1L7H}
}
同时请引用ByT5论文:
@article{xue-etal-2022-byt5,
title = "{B}y{T}5:基于预训练字节到字节模型的去令牌化未来探索",
author = "Xue, Linting and
Barua, Aditya and
Constant, Noah and
Al-Rfou, Rami and
Narang, Sharan and
Kale, Mihir and
Roberts, Adam and
Raffel, Colin",
editor = "Roark, Brian and
Nenkova, Ani",
journal = "计算语言学协会汇刊",
volume = "10",
year = "2022",
address = "剑桥,马萨诸塞州",
publisher = "麻省理工出版社",
url = "https://aclanthology.org/2022.tacl-1.17",
doi = "10.1162/tacl_a_00461",
pages = "291--306",
}
模型详情
本技术说明卡针对12.3亿参数的MrT5大模型(mrt5-large
),该模型是ByT5大模型(google/byt5-large
)的高效改进版,平均可缩短约50%的序列长度。
- 开发团队:Julie Kallini, Shikhar Murty, Christopher D. Manning, Christopher Potts, Róbert Csordás
- 模型类型:MrT5
- 支持语言:英语、法语、西班牙语、德语、希腊语、保加利亚语、俄语、土耳其语、阿拉伯语、越南语、泰语、汉语、印地语、斯瓦希里语和乌尔都语
- 基础模型:google/byt5-large
- 更多信息:
模型架构
MrT5大模型采用标准ByT5大模型的配置:前馈网络维度3840,模型维度1536,36个编码层,12个解码层,每层16个注意力头,总计12.3亿参数。
MrT5新增的删除门控可动态缩减编码器序列长度。本模型在第三编码层后部署该机制,后续所有层均在缩减后的序列上运行。训练时采用δ=0.5的删除率,即第三层后序列长度缩减约50%。该门控机制仅额外引入3000个参数。
模型基于ByT5大模型初始化,仅随机初始化删除门控后继续训练。另一特点是采用softmax1注意力机制。
使用场景
本模型为编码器-解码器架构,主要用于序列到序列任务。虽然可直接用于探索研究,但建议针对具体下游任务进行微调以获得最佳性能。
要启用动态删除功能,请使用配套代码库中的MrT5Trainer专用训练器。该模型作为学术研究用途的基础模型,不建议直接用于生产环境。
偏差与风险
语言模型可能产生社会偏见或有害内容(参见Bender等, 2021;Bommasani等, 2022;Liang等, 2022)。本模型未经安全微调,敏感场景需谨慎使用。
快速开始
与ByT5类似,MrT5直接处理UTF-8字节流,无需分词器。加载时需设置trust_remote_code=True
:
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("生活就像一盒巧克力".encode("utf-8"))]) + 3
labels = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3
loss = model(input_ids, labels=labels, hard_delete=True).loss
批处理场景可使用ByT5的分词器:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["生活就像一盒巧克力", "今天是周一"], padding="longest", return_tensors="pt")
labels = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt").input_ids
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
训练详情
训练数据
采用多语言C4语料库(Raffel等, 2020;Xue等, 2021),涵盖15种语言并确保各语言字节量均衡。
训练过程
基于ByT5的跨度损坏预训练目标:用哨兵令牌标记随机掩码的字节跨度(平均跨度长度20令牌,噪声密度15%)。
优化配置
- 训练步数:5,000步
- 批次规模:2^20令牌(编码器序列长度1024,有效批次1024)
- 优化器:AdamW(初始学习率1e-4,线性衰减)
- 删除控制:PI控制器(目标删除率δ=0.5)
- 正则化:采用论文附录D所述的注意力分数正则化
环境影响
- 硬件:NVIDIA A100-SXM4-80GB
- GPU数量:4
- 训练时长:约73小时
- 计算平台:斯坦福NLP集群
技术说明卡作者
Julie Kallini
kallini@stanford.edu