许可证: mit
数据集:
- adrianhenkel/lucidprots_full_data
管道标签: translation
标签:
- biology
ProstT5模型卡
ProstT5是一种蛋白质语言模型(pLM),能够在蛋白质序列与结构之间进行翻译。

模型详情
模型描述
ProstT5(蛋白质结构序列T5)基于ProtT5-XL-U50,这是一个通过跨度去噪目标在数十亿蛋白质序列上训练的T5模型。ProstT5在AlphaFoldDB提供的1700万高质量3D结构预测蛋白质上微调了ProtT5-XL-U50,实现了蛋白质序列与结构间的翻译。蛋白质结构通过Foldseek引入的3Di标记从3D转换到1D表示。
ProstT5首先通过延续原始的跨度去噪目标学习表示新引入的3Di标记(应用于3Di和氨基酸序列)。第二阶段才训练模型在两种模态间进行翻译。翻译方向由两个特殊标记指示(“<fold2AA>”表示从3Di到氨基酸的翻译,“<AA2fold>”表示反向翻译)。为避免与氨基酸标记冲突,3Di标记转为小写(字母表其他部分相同)。
用途
-
特征提取:推荐使用半精度(fp16)的编码器配合批处理。示例脚本和Colab见链接(替换仓库链接并添加前缀即可适配ProstT5)。
- 原ProtT5仅能嵌入氨基酸序列,而ProstT5新增了对3Di标记(代表3D结构)的嵌入能力。3Di标记可通过Foldseek从3D结构导出,或由ProstT5从氨基酸序列预测生成。
-
“折叠”功能:将氨基酸序列(AA)翻译为3Di结构序列。生成的3Di字符串可与Foldseek结合用于远程同源检测,无需显式计算3D结构。
-
“逆向折叠”功能:将3Di结构序列翻译回氨基酸序列。
快速开始
特征提取示例代码:
from transformers import T5Tokenizer, T5EncoderModel
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False).to(device)
model = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(device)
model.full() if device=='cpu' else model.half()
sequence_examples = ["PRTEINO", "strct"]
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequence_examples]
sequence_examples = [ "<AA2fold> " + s if s.isupper() else "<fold2AA> " + s for s in sequence_examples ]
ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest", return_tensors='pt').to(device)
with torch.no_grad():
embedding_rpr = model(ids.input_ids, attention_mask=ids.attention_mask)
emb_0 = embedding_rpr.last_hidden_state[0,1:8]
emb_0_per_protein = emb_0.mean(dim=0)
翻译(“折叠”AA→3Di)示例代码:
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False).to(device)
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
model.full() if device=='cpu' else model.half()
sequence_examples = ["PRTEINO", "SEQWENCE"]
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequence_examples]
sequence_examples = ["<AA2fold> " + s for s in sequence_examples]
ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest", return_tensors='pt').to(device)
gen_kwargs_aa2fold = {
"do_sample": True, "num_beams": 3, "top_p": 0.95,
"temperature": 1.2, "top_k": 6, "repetition_penalty": 1.2
}
with torch.no_grad():
translations = model.generate(
ids.input_ids, attention_mask=ids.attention_mask,
max_length=max_len, min_length=min_len,
early_stopping=True, num_return_sequences=1,
**gen_kwargs_aa2fold
)
decoded_translations = tokenizer.batch_decode(translations, skip_special_tokens=True)
structure_sequences = ["".join(ts.split(" ")) for ts in decoded_translations]
训练详情
训练数据
预训练数据(17M蛋白质的3Di+AA序列)
训练流程
- 第一阶段:使用脚本延续3Di和AA序列的跨度去噪预训练。
- 第二阶段:使用脚本进行双向翻译训练。
超参数
- 训练配置:DeepSpeed(stage-2)、梯度累积(5步)、混合精度(bf16)及PyTorch 2.0的torchInductor编译器。
速度
- 嵌入生成:人类蛋白质组约35分钟(单个RTX A6000 GPU,半精度),平均每蛋白0.1秒。
- 翻译速度:因需逐标记解码,较慢(0.6-2.5秒/蛋白,长度135-406不等),当前仅支持半精度批处理未进一步优化。