---
license: cc-by-nc-nd-4.0
datasets:
- vandijklab/immune-c2s
language:
- en
tags:
- pytorch
- causal-lm
- scRNA-seq
---
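Model Overview
This is EleutherAI's Pythia-160m model, fine-tuned with the Cell2Sentence method on full single-cell RNA sequencing (scRNA-seq) data.
Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics. We convert scRNA-seq data into sequences of gene names ordered by expression level, which we call "cell sentences".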
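The ranking step described above can be sketched as follows. This is a simplified illustration and not the Cell2Sentence implementation. The helper name `expression_to_cell_sentence` is hypothetical. The sketch makes three further assumptions: genes with zero expression are dropped, ties keep their original gene order, and any normalization applied before ranking is out of scope. See the paper and the GitHub repository for the actual preprocessing.

```python
from typing import Sequence


def expression_to_cell_sentence(
    expression: Sequence[float],
    gene_names: Sequence[str],
) -> str:
    """Rank genes by descending expression and keep only nonzero entries.

    Hypothetical helper: a simplified sketch, not the Cell2Sentence code.
    Ties keep the original gene order because Python's sort is stable
    (also with reverse=True).
    """
    if len(expression) != len(gene_names):
        raise ValueError("expression and gene_names must have the same length")
    ranked = sorted(
        ((value, name) for value, name in zip(expression, gene_names) if value > 0),
        key=lambda pair: pair[0],
        reverse=True,
    )
    return " ".join(name for _, name in ranked)


# Toy example with made-up expression values
genes = ["CD3E", "MS4A1", "IL7R", "NKG7"]
values = [2.5, 0.0, 4.1, 1.2]
print(expression_to_cell_sentence(values, genes))  # IL7R CD3E NKG7
```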
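For more details, see the paper linked below. The model was trained on the immune tissue dataset from Domínguez et al. for about 20 hours on 8 A100 40GB GPUs, on the following tasks:
- conditional cell generation
- unconditional cell generation
- cell type prediction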
Cell2Sentence links:
GitHub: https://github.com/vandijklab/cell2sentence-ft
Paper: https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3
Pythia links:
GitHub: https://github.com/EleutherAI/pythia
Paper: https://arxiv.org/abs/2304.01373
Hugging Face: https://huggingface.co/EleutherAI/pythia-160m
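Evaluation
The model was evaluated with k-nearest-neighbor (k-NN) classification and the Gromov-Wasserstein (GW) distance. Each generated cell is labeled with the cell type used in its generation prompt. Real cells are sampled with replacement from the held-out test set. Generated cells are converted into expression vectors using the method described in the paper. For the k-NN scores, higher is better (↑). For the GW distance, lower is better (↓). See the paper for full experimental details.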
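Below is a minimal NumPy-only sketch of the k-NN part of this evaluation. It assumes the real cells form the labeled reference set. Each generated cell is then classified by a majority vote of its k nearest real cells under Euclidean distance. The score is the fraction of generated cells whose predicted label matches their prompt cell type. Several details are assumptions: the function name `knn_accuracy`, the distance metric, the direction of classification and the toy data. The paper's protocol may differ, and the GW distance is not shown.

```python
import numpy as np


def knn_accuracy(
    real_x: np.ndarray,
    real_labels: np.ndarray,
    gen_x: np.ndarray,
    gen_labels: np.ndarray,
    k: int = 3,
) -> float:
    """Label each generated cell by majority vote of its k nearest real cells
    (Euclidean distance) and return the fraction matching its prompt label.

    Hypothetical helper; the evaluation direction is an assumption.
    """
    # Pairwise squared Euclidean distances, shape (n_generated, n_real)
    dists = ((gen_x[:, None, :] - real_x[None, :, :]) ** 2).sum(axis=-1)
    nearest = np.argsort(dists, axis=1)[:, :k]
    correct = 0
    for i, neighbor_idx in enumerate(nearest):
        labels, counts = np.unique(real_labels[neighbor_idx], return_counts=True)
        predicted = labels[np.argmax(counts)]  # ties go to the label that sorts first
        correct += int(predicted == gen_labels[i])
    return correct / len(gen_x)


rng = np.random.default_rng(0)
# Made-up held-out test set: 2 cell types, 5 genes
test_x = np.vstack([rng.normal(0, 1, (50, 5)), rng.normal(5, 1, (50, 5))])
test_labels = np.array(["T cell"] * 50 + ["B cell"] * 50)
# Sample real reference cells with replacement
idx = rng.choice(len(test_x), size=100, replace=True)
real_x, real_labels = test_x[idx], test_labels[idx]
# Stand-in "generated" cells, labeled with the cell type from their prompt
gen_x = np.vstack([rng.normal(0, 1, (10, 5)), rng.normal(5, 1, (10, 5))])
gen_labels = np.array(["T cell"] * 10 + ["B cell"] * 10)
print(knn_accuracy(real_x, real_labels, gen_x, gen_labels, k=3))
```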

| Model | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓) |
|---|---|---|---|---|---|
| scGEN | 0.2376 | 0.2330 | 0.2377 | 0.2335 | 315.9505 |
| scVI | 0.2436 | 0.2400 | 0.2425 | 0.2348 | 302.1285 |
| scDiffusion | 0.2335 | 0.2288 | 0.2368 | 0.2306 | 72.0208 |
| scGPT | 0.1838 | 0.1788 | 0.1811 | 0.1882 | 2989.8066 |
| C2S (Pythia-160m) | 0.2588 | 0.2565 | 0.2746 | 0.2715 | 54.3040 |

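Example Code
We provide an example of how to use the model for conditional cell generation. It includes a post-processing function that removes duplicated and invalid genes. To generate full cells, change the `max_length` generation parameter to 9200. Full-cell generation is best run on an A100 GPU, for both inference speed and memory capacity. Prompts for unconditional cell generation and cell type prediction are also included, but no example cell sentence is provided to format the cell type prediction prompt. For how to convert expression vectors into cell sentences, see the paper and the GitHub repository.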
```python
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List[str],
) -> List[str]:
    """
    Post-process a generated cell sentence.
    Invalid genes are removed and the ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence: generated cell sentence string
        gene_dictionary: list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence: generated cell sentence after post-processing
    """
    vocab = set(gene_dictionary)  # set for fast membership checks
    generated_gene_names = [gene.upper() for gene in cell_sentence.split(" ")]

    # Remove genes that are not in the vocabulary (this also drops empty strings)
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in vocab]

    # Average the ranks of duplicated genes
    gene_name_to_occurrences = Counter(generated_gene_names)
    post_processed_sentence = generated_gene_names.copy()
    for gene_name, count in gene_name_to_occurrences.items():
        if count > 1:
            # Positions come from the current (shrinking) list, not the original one
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))
            # Remove every occurrence, then reinsert the gene once at the average position
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]
            post_processed_sentence.insert(average_position, gene_name)
    return post_processed_sentence


vocab_path = "pbmc_vocab.json"
with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Conditional cell generation prompt
cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

tokens = tokenizer(ccg, return_tensors="pt")
input_ids = tokens["input_ids"].to(torch.device("cuda"))
attention_mask = tokens["attention_mask"].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,  # set to 9200 to generate full cells
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Drop the prompt: keep only what follows the first "?", "." or ":" (later delimiters are removed too)
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
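The genes returned by `post_process_generated_cell_sentences` are ranked, not quantified. One way to map them back to an expression vector is sketched below. It assumes a log-linear relationship between a gene's rank and its expression, expression ≈ slope · log10(rank) + intercept, fitted on real cells with `numpy.polyfit`. The helper names `fit_rank_to_expression` and `cell_sentence_to_expression` are hypothetical. The clipping of negative predictions to zero and the toy data are also illustrative assumptions. See the paper for the method actually used to build the evaluation vectors.

```python
from typing import List, Sequence, Tuple

import numpy as np


def fit_rank_to_expression(real_cells: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Fit expression ~ slope * log10(rank) + intercept over the nonzero genes of real cells.

    Hypothetical helper using an assumed log-linear model, not the paper's code.
    """
    log_ranks, values = [], []
    for cell in real_cells:
        nonzero = np.sort(np.asarray(cell, dtype=float))[::-1]  # descending order
        nonzero = nonzero[nonzero > 0]
        log_ranks.append(np.log10(np.arange(1, len(nonzero) + 1)))  # rank 1 = highest
        values.append(nonzero)
    # polyfit with deg=1 returns [slope, intercept]
    slope, intercept = np.polyfit(np.concatenate(log_ranks), np.concatenate(values), deg=1)
    return float(slope), float(intercept)


def cell_sentence_to_expression(
    ranked_genes: List[str], vocab: List[str], slope: float, intercept: float
) -> np.ndarray:
    """Map a ranked gene list onto a dense vector over `vocab` (hypothetical helper).

    Genes absent from the sentence stay at 0; negative predictions are clipped to 0.
    """
    index = {gene: i for i, gene in enumerate(vocab)}
    vector = np.zeros(len(vocab))
    for rank, gene in enumerate(ranked_genes, start=1):
        if gene in index:
            vector[index[gene]] = max(slope * np.log10(rank) + intercept, 0.0)
    return vector


# Made-up real cell whose nonzero values follow 3 - log10(rank) exactly
real = [[3 - np.log10(r) for r in range(1, 11)] + [0.0, 0.0]]
slope, intercept = fit_rank_to_expression(real)
vocab = ["CD3E", "IL7R", "MS4A1"]
print(cell_sentence_to_expression(["IL7R", "CD3E"], vocab, slope, intercept))
```

In the example above, `processed_genes` and `gene_dictionary` would take the place of the toy gene list and `vocab`.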