库名称:transformers
标签:[]
注意
此前Huggingface权重绑定存在一个错误,导致ESM++的logits与ESMC不一致。该问题现已修复。
ESM++
ESM++是对ESMC(许可证)的忠实实现,支持批处理且兼容标准Huggingface接口,无需依赖ESM Python包。
小型版本对应ESMC的3亿参数版本。
使用🤗 transformers
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
output = model(**tokenized)
print(output.logits.shape)
print(output.last_hidden_state.shape)
print(output.loss)
ESM++与ESM2类似,支持序列和标记级别的分类任务。初始化时传入标签数量即可。
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape)
ESM++权重默认为fp32格式。可按如下方式加载fp16或bf16版本:
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16)
无需新代码即可嵌入整个数据集
要快速嵌入蛋白质序列列表,只需调用embed_dataset方法。序列会按长度排序以减少填充标记,因此初始进度条预估时间通常远长于实际耗时。
示例:
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ...
],
tokenizer=model.tokenizer,
batch_size=2,
max_len=512,
full_embeddings=False,
embed_dtype=torch.float32,
pooling_types=['mean', 'cls'],
num_workers=0,
sql=False,
sql_db_path='embeddings.db',
save=True,
save_path='embeddings.pth',
)
model.embed_dataset()
参数说明:
sequences: 蛋白质序列列表
batch_size: 处理批大小
max_len: 最大序列长度
full_embeddings: 返回完整残基级嵌入(True)还是池化结果(False)
pooling_type: 池化类型('mean'或'cls')
num_workers: 数据加载工作进程数,0表示主进程
sql: 是否将嵌入存储到SQLite数据库(仅支持float32格式)
sql_db_path: SQLite数据库路径
返回值:
序列到嵌入向量的字典映射,若sql=True则返回None
注意:
- sql=True时嵌入结果仅能存储为float32格式
- 需要实时流式处理超大数据集时推荐使用sql选项
- 若嵌入字典可完全载入内存,推荐使用save=True选项
- 若同时设置sql和save参数,优先使用sql选项
- 若数据库或.pth文件已存在,将优先扫描已嵌入的序列
- 序列将截断至max_len并按长度降序排列以加速处理
使用🤗 peft进行微调
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
for param in model.classifier.parameters():
param.requires_grad = True
更完整的微调示例请参见我们的示例脚本。
返回注意力图
通常使用F.scaled_dot_product_attention进行注意力计算(比原生PyTorch实现更快),但该方法无法返回注意力图。
ESM++可通过设置output_attentions
参数手动计算注意力图。这会显著降低速度,建议仅在需要时使用。
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att)
浮点精度与实现的对比
我们测量了fp32权重与fp16/bf16版本最后隐藏状态的差异。发现fp16更接近fp32输出,因此推荐加载fp16版本。
需注意ESM包默认以fp32加载ESMC但会转为bf16,在推理/训练中各有利弊——半精度版本可自由选择。
FP32 vs. FP16平均MSE: 0.00000003
FP32 vs. BF16平均MSE: 0.00000140
我们还测量了ESM++与ESMC(均为bfloat16)在1000条随机序列上的输出差异,确保与ESM包的一致性。
最后隐藏状态平均MSE: 7.74e-10
可通过.from_pretrained_esm('esmc_300m')替代.from_pretrained(...)从ESM包加载权重。
模型探针
我们采用类似先前论文的线性探针技术,评估不同PLM在标准数据集上池化隐藏状态与有效属性之间的内在关联。ESMC(及ESM++)表现优异。
下图展示了归一化后的性能对比(负对照为随机向量嵌入,基准为最佳表现者)。分类任务取MCC与F1(多标签任务取F1max)平均值,回归任务取Spearman rho与R2平均值。

推理速度
我们测试了H100上不同ESM模型的吞吐量。相比ESMC,ESM++通过高效批处理显著提升吞吐量,即使在批大小为1时也更快。对于长序列,ESM++ small甚至快于ESM2-35M!
在Linux系统PyTorch > 2.5环境下能获得最大增益。

引用
若使用本实现或相关工作,请引用(同时引用ESMC预印本)。
@misc {ESMPlusPlus,
author = { Hallee, L. and Bichara, D. and Gleghorn, J, P. },
title = { ESMPlusPlus },
year = 2024,
url = { https://huggingface.co/Synthyra/ESMplusplus_small },
doi = { 10.57967/hf/3725 },
publisher = { Hugging Face }
}