library_name: transformers
tags: []
重要说明
此前Huggingface权重绑定存在一个bug,导致ESM++的输出逻辑与ESMC不一致。该问题现已修复。
ESM++模型
ESM++是对ESMC(非商业许可协议)的忠实实现,支持批处理操作且完全兼容Huggingface生态,无需依赖ESM官方Python包。
其中large版本对应ESMC的6亿参数模型。
使用🤗 transformers调用
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
output = model(**tokenized)
print(output.logits.shape)
print(output.last_hidden_state.shape)
print(output.loss)
与ESM2类似,ESM++也支持序列级和token级分类任务。初始化时指定标签数量即可:
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape)
ESM++默认加载fp32精度权重,支持fp16/bf16加载方式:
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16)
零代码实现数据集嵌入
通过embed_dataset方法可快速嵌入蛋白质序列列表。系统会自动按长度排序减少填充token,初始进度条预估时间通常远长于实际耗时。
示例:
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ...
],
tokenizer=model.tokenizer,
batch_size=2,
max_len=512,
full_embeddings=False,
embed_dtype=torch.float32,
pooling_types=['mean', 'cls'],
num_workers=0,
sql=False,
sql_db_path='embeddings.db',
save=True,
save_path='embeddings.pth',
)
model.embed_dataset()参数说明:
sequences: 蛋白质序列列表
batch_size: 处理批次大小
max_len: 最大序列长度
full_embeddings: 返回完整残基级嵌入(True)或池化结果(False)
pooling_type: 池化类型('mean'或'cls')
num_workers: 数据加载线程数,0表示主进程处理
sql: 是否存入SQLite数据库(强制float32精度)
sql_db_path: 数据库路径
返回:
序列到嵌入向量的映射字典(sql=True时返回None)
注意:
- SQL模式适合需要实时流式处理超大规模训练数据的场景
- save模式适合能全量加载嵌入字典到内存的情况
- 当sql和save同时设置时优先使用SQL存储
- 会自动检测已有数据库/pth文件中的缓存序列
- 序列将按长度降序排列并截断至max_len以优化处理速度
使用🤗 peft微调
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
for param in model.classifier.parameters():
param.requires_grad = True
完整微调示例参见脚本。
获取注意力图谱
默认使用F.scaled_dot_product_attention加速计算,但无法返回注意力图谱。设置output_attentions=True
将切换为手动计算模式(速度较慢,仅需时启用)。
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att)
精度与实现对比
我们测量了fp32权重在fp16/bf16下的隐藏状态差异,发现fp16更接近fp32输出。需注意ESM官方包默认将ESMC转为bf16,各有优劣,可按需选择半精度类型。
FP16平均MSE: 0.00000003
BF16平均MSE: 0.00000122
在1000条随机序列上对比ESM++与ESMC(均为bfloat16)输出,确保与官方包一致性:
最终隐藏状态平均MSE: 2.46e-09
可通过.from_pretrained_esm('esmc_600m')直接加载ESM官方包权重。
模型探针分析
我们采用类似先前研究的线性探针方法评估隐藏状态与生物特性的关联性。ESMC(及ESM++)表现优异。
下图显示各模型在标准化指标下的表现(分类任务取MCC与F1平均值,回归任务取Spearman ρ与R²平均值):

推理速度基准
在H100显卡上测试不同ESM模型的吞吐量。相比ESMC,ESM++的批处理优化带来显著加速,单批次推理也更快。ESM++ small处理长序列时甚至快于ESM2-35M!Linux系统搭配PyTorch > 2.5可获得最佳性能。

引用声明
使用本实现或相关研究成果时请引用(同时需引用ESMC预印本):
@misc {ESMPlusPlus,
author = { Hallee, L. and Bichara, D. and Gleghorn, J, P. },
title = { ESMPlusPlus },
year = 2024,
url = { https://huggingface.co/Synthyra/ESMplusplus_small },
doi = { 10.57967/hf/3726 },
publisher = { Hugging Face }
}