基础模型:
- BAAI/bge-m3
库名称: treehop-rag
许可证: mit
管道标签: 文本排序
标签:
- 信息检索
- 检索增强生成
- 模型中心混合
- 多跳问答
- pytorch模型中心混合
TreeHop:高效生成与过滤多跳问答中的下一查询嵌入

目录
简介
TreeHop是一个轻量级的嵌入级框架,旨在解决检索增强生成(RAG)领域中传统递归检索范式的计算效率问题。通过消除基于LLM的迭代查询重写的需求,TreeHop显著降低了延迟,同时保持了最先进的性能。它通过动态查询嵌入更新和剪枝策略实现这一点,从而实现了简化的“检索-嵌入-检索”工作流程。

为何选择TreeHop进行多跳检索?
- 处理复杂查询:现实世界的问题通常需要多次跳转才能检索到相关信息,传统检索方法难以应对。
- 成本效益高:2500万参数与现有查询重写器的数十亿参数相比,显著降低了计算开销。
- 速度快:与迭代LLM方法相比,推理速度快99%,非常适合对响应速度要求高的工业应用。
- 性能优异:在控制检索段落数量的情况下保持高召回率,确保相关性而不压垮系统。

系统要求
Ubuntu 18.06 LTS+ 或 MacOS Big Sur+。
最低要求Nvidia GPU或Apple Metal,32GB内存。
用于复现论文需16GB系统内存,训练需64GB。
硬盘需50GB可用空间。
Python环境
请参考requirements.txt。
预备知识
本仓库附带评估嵌入数据库用于复现目的。激活git LFS后使用git lfs clone [仓库链接]
克隆仓库,或在现有本地仓库中通过以下命令拉取数据:
git lfs pull
嵌入数据库
我们采用BGE-m3生成嵌入,并在此基础上训练TreeHop模型进行多跳检索。
运行以下两个脚本生成所有必要的训练和评估嵌入数据库。
如果不计划训练TreeHop,则无需运行这些脚本,因为仓库中已提供所有必要的评估嵌入数据库。
python init_train_vectors.py
python init_multihop_rag.py
使用TreeHop进行多跳检索:操作指南
以下示例使用MultiHop RAG评估数据集。
本仓库附带运行示例所需的文件,参见预备知识。
from tree_hop import TreeHopModel
from passage_retrieval import MultiHopRetriever
EVALUATE_DATASET = "multihop_rag"
tree_hop_model = TreeHopModel.from_pretrained("allen-li1231/treehop-rag")
retriever = MultiHopRetriever(
"BAAI/bge-m3",
passages=f"embedding_data/{EVALUATE_DATASET}/eval_passages.jsonl",
passage_embeddings=f"embedding_data/{EVALUATE_DATASET}/eval_content_dense.npy",
tree_hop_model=tree_hop_model,
projection_size=1024,
save_or_load_index=True,
indexing_batch_size=10240,
index_device="cuda"
)
:bell: 注意
retriever
的multihop_search_passages
方法支持检索单个查询和批量查询。
对于单个查询:
retrieve_result = retriever.multihop_search_passages(
"Engadget是否在The Verge报道三星Galaxy Buds 2折扣之前报道了13.6英寸MacBook Air的折扣?",
n_hop=2,
top_n=5
)
对于批量查询:
LIST_OF_QUESTIONS = [
"Engadget是否在The Verge报道三星Galaxy Buds 2折扣之前报道了13.6英寸MacBook Air的折扣?",
"'The Independent - Travel'是否在'Essentially Sports'提及Jeff Shiffrin的滑雪习惯之前报道了Tremblant滑雪度假村?"
]
retrieve_result = retriever.multihop_search_passages(
LIST_OF_QUESTIONS,
n_hop=2,
top_n=5,
index_batch_size=2048,
generate_batch_size=1024
)
访问检索到的段落及对应的多跳检索路径:
print(retrieve_result.passage)
retrieve_result = retriever.multihop_search_passages(
LIST_OF_QUESTIONS,
n_hop=2,
top_n=5,
index_batch_size=2048,
generate_batch_size=1024,
return_tree=True
)
retrieval_tree = retrieve_result.tree_hop_graph[0]
retrieval_tree.plot_tree()
print(retrieval_tree.nodes(data=True))
论文复现
要评估TreeHop的多跳检索性能,运行以下代码。这里以2WikiMultihop数据集和三跳下的recall@5为例。
脚本将打印每跳的召回率和平均检索段落数,以及按问题类型的统计数据。
:bell: 注意
- 要更改评估数据集,将
2wiki
替换为musique
或multihop_rag
。
- 修改
n_hop
和top_n
以更改跳数和顶部检索设置。
- 切换
redundant_pruning
和layerwise_top_pruning
以复现我们对停止标准的消融研究。
python evaluation.py \
--dataset_name multihop_rag \
--revision paper-reproduction \
--n_hop 3 \
--top_n 5 \
--redundant_pruning True \
--layerwise_top_pruning True
训练TreeHop
运行以下代码生成图并训练TreeHop。有关脚本参数的详细信息,请参考training.py中的parse_args
函数。
有关训练嵌入生成,请参考init_train_vectors.py中的代码。
python training.py --graph_cache_dir ./train_data/
引用
@misc{li2025treehopgeneratefilterquery,
title={TreeHop: Generate and Filter Next Query Embeddings Efficiently for Multi-hop Question Answering},
author={Zhonghao Li and Kunpeng Zhang and Jinghuai Ou and Shuliang Liu and Xuming Hu},
year={2025},
eprint={2504.20114},
archivePrefix={arXiv},
primaryClass={cs.IR},
url={https://arxiv.org/abs/2504.20114},
}