🚀 Self-RAG 13B模型
本模型是一个13B的 Self-RAG 模型,它能够针对用户的各种查询生成输出,还能生成 反思标记(reflection tokens),以自适应地调用检索系统,并对自身的输出和检索到的段落进行评估。
Self-RAG 在我们的指令跟随语料库上进行训练,这些语料库包含交错的段落和反思标记,采用标准的下一个标记预测目标,从而能够通过细粒度的反馈实现高效且稳定的学习。在推理阶段,我们利用涵盖生成各个方面的反思标记,来采样出最符合用户偏好的输出。更多详细描述请参阅 我们的论文。
🚀 快速开始
✨ 主要特性
- 能够针对用户的多样化查询生成输出。
- 生成反思标记,自适应调用检索系统并评估自身输出和检索段落。
- 在包含交错段落和反思标记的指令跟随语料库上训练,实现高效稳定学习。
- 推理时利用反思标记采样最符合用户偏好的输出。
📦 安装指南
请确保安装 self-rag/requirements.txt 中列出的依赖项。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
model = LLM("selfrag/selfrag_llama2_13b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
def format_prompt(input, paragraph=None):
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt
query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
print("Model prediction: {0}".format(pred.outputs[0].text))
prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", paragraph="The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
📚 详细文档
输入格式
如 format_prompt
函数中所述,输入应按照以下格式:
### Instruction:\n{instruction}\n\n### Response:\n".format(instruction)
或者,如果有额外输入:
### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
你可以在 ### Response:\n"
之后的任何位置插入段落,但要确保将段落标记为段落标记(即 <paragraph>{0}</paragraph>
)。
🔧 技术细节
- 训练数据:我们的训练数据可在HuggingFace数据集 selfrag_train_data 中获取。
- 训练环境:我们在Stability HPC服务器上使用8个A100 40GB进行训练。
📄 许可证
本项目采用MIT许可证。
引用说明
如果您使用此模型,请引用我们的工作:
@article{asai2023selfrag,
author = {Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title = {{Self-RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
year = {2023},
journal = { arXiv preprint arXiv:2310.11511 },
URL = {https://arxiv.org/abs/2310.11511}
}