license: apache-2.0
ProLLaMA:面向多任务蛋白质语言处理的蛋白质大语言模型
论文详情请见arxiv
Github项目地址
ProLLaMA基于Llama-2-7b模型构建,请遵守Llama2的许可协议。
输入格式:
输入模型的指令需遵循以下格式:
[按超家族生成] 超家族=<xxx>
或
[判定超家族] 序列=<yyy>
输入示例如下:
[按超家族生成] 超家族=<含锚蛋白重复结构域超家族>
# 可指定蛋白质序列的首个氨基酸:
[按超家族生成] 超家族=<含锚蛋白重复结构域超家族> 序列=<MKRVL
[判定超家族] 序列=<MAPGGMPREFPSFVRTLPEADLGYPALRGWVLQGERGCVLYWEAVTEVALPEHCHAECWGVVVDGRMELMVDGYTRVYTRGDLYVVPPQARHRARVFPGFRGVEHLSDPDLLPVRKR>
所有可选超家族列表请参见此文件。
快速使用:
CUDA_VISIBLE_DEVICES=0 python main.py --model "GreatCaptainNemo/ProLLaMA" --interactive
import argparse
import json, os
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
from tqdm import tqdm
generation_config = GenerationConfig(
temperature=0.2,
top_k=40,
top_p=0.9,
do_sample=True,
num_beams=1,
repetition_penalty=1.2,
max_new_tokens=400
)
parser = argparse.ArgumentParser()
parser.add_argument('--model', default=None, type=str,help="模型本地路径。若为None,则从HuggingFace下载模型")
parser.add_argument('--interactive', action='store_true',help="若为True,可交互式输入指令;若为False,则需通过input_file输入指令")
parser.add_argument('--input_file', default=None, help="包含所有输入指令的文件(每行一条指令)")
parser.add_argument('--output_file', default=None, help="输出结果保存路径")
args = parser.parse_args()
if __name__ == '__main__':
if args.interactive and args.input_file:
raise ValueError("交互模式已开启,但input_file参数不为空")
if (not args.interactive) and (args.input_file is None):
raise ValueError("非交互模式下input_file参数不能为空")
if args.input_file and (args.output_file is None):
raise ValueError("指定了input_file但未指定output_file")
load_type = torch.bfloat16
if torch.cuda.is_available():
device = torch.device(0)
else:
raise ValueError("未检测到可用GPU")
model = LlamaForCausalLM.from_pretrained(
args.model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
quantization_config=None
)
tokenizer = LlamaTokenizer.from_pretrained(args.model)
model.eval()
with torch.no_grad():
if args.interactive:
while True:
raw_input_text = input("请输入:")
if len(raw_input_text.strip())==0:
break
input_text = raw_input_text
input_text = tokenizer(input_text,return_tensors="pt")
generation_output = model.generate(
input_ids = input_text["input_ids"].to(device),
attention_mask = input_text['attention_mask'].to(device),
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config = generation_config,
output_attentions=False
)
s = generation_output[0]
output = tokenizer.decode(s,skip_special_tokens=True)
print("输出结果:",output)
print("\n")
else:
outputs=[]
with open(args.input_file, 'r') as f:
examples =f.read().splitlines()
print("开始生成...")
for index, example in tqdm(enumerate(examples),total=len(examples)):
input_text = tokenizer(example,return_tensors="pt")
generation_output = model.generate(
input_ids = input_text["input_ids"].to(device),
attention_mask = input_text['attention_mask'].to(device),
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config = generation_config
)
s = generation_output[0]
output = tokenizer.decode(s,skip_special_tokens=True)
outputs.append(output)
with open(args.output_file,'w') as f:
f.write("\n".join(outputs))
print("所有输出结果已保存至",args.output_file)
引用:
@article{lv2024prollama,
title={ProLLaMA:面向多任务蛋白质语言处理的蛋白质大语言模型},
author={吕六正浩,林宗英,李浩,刘宇阳,崔佳曦,陈宇谦,袁莉,田永鸿},
journal={arXiv预印本 arXiv:2402.16445},
year={2024}
}