🚀 TreeHop:高效生成和过滤多跳问答的下一个查询嵌入
TreeHop 是一个轻量级的嵌入级框架,旨在解决检索增强生成(RAG)领域中传统递归检索范式的计算效率低下问题。通过消除基于大语言模型(LLM)的迭代查询重写需求,TreeHop 在保持先进性能的同时显著降低了延迟。它通过动态查询嵌入更新和剪枝策略,实现了简化的“检索 - 嵌入 - 检索”工作流程。
项目链接
🚀 快速开始
系统要求
- 操作系统:Ubuntu 18.06 LTS 及以上版本,或 MacOS Big Sur 及以上版本。
- 硬件要求:Nvidia GPU 或配备 32GB 内存的 Apple Metal,系统内存 16GB 用于复现论文,64GB 用于训练 TreeHop,硬盘需有 50GB 可用空间。
- Python 环境:Python 3.9 及以上版本,具体依赖请参考 requirements.txt。
准备工作
本仓库包含用于复现的评估嵌入数据库。激活 git LFS 后,使用 git lfs clone [仓库链接]
克隆仓库,或在现有本地仓库中使用以下命令拉取数据:
git lfs pull
多跳检索使用示例
以下示例使用 MultiHop RAG 评估数据集。仓库中包含运行示例所需的文件,请参考准备工作。
from tree_hop import TreeHopModel
from passage_retrieval import MultiHopRetriever
EVALUATE_DATASET = "multihop_rag"
tree_hop_model = TreeHopModel.from_pretrained("allen-li1231/treehop-rag")
retriever = MultiHopRetriever(
"BAAI/bge-m3",
passages=f"embedding_data/{EVALUATE_DATASET}/eval_passages.jsonl",
passage_embeddings=f"embedding_data/{EVALUATE_DATASET}/eval_content_dense.npy",
tree_hop_model=tree_hop_model,
projection_size=1024,
save_or_load_index=True,
indexing_batch_size=10240,
index_device="cuda"
)
⚠️ 重要提示
retriever
具有 multihop_search_passages
方法,支持单查询和批量查询检索。
单查询示例
retrieve_result = retriever.multihop_search_passages(
"Did Engadget report a discount on the 13.6-inch MacBook Air \
before The Verge reported a discount on Samsung Galaxy Buds 2?",
n_hop=2,
top_n=5
)
批量查询示例
LIST_OF_QUESTIONS = [
"Did Engadget report a discount on the 13.6-inch MacBook Air \
before The Verge reported a discount on Samsung Galaxy Buds 2?",
"Did 'The Independent - Travel' report on Tremblant Ski Resort \
before 'Essentially Sports' mentioned Jeff Shiffrin's skiing habits?"
]
retrieve_result = retriever.multihop_search_passages(
LIST_OF_QUESTIONS,
n_hop=2,
top_n=5,
index_batch_size=2048,
generate_batch_size=1024
)
访问检索结果
print(retrieve_result.passage)
retrieve_result = retriever.multihop_search_passages(
LIST_OF_QUESTIONS,
n_hop=2,
top_n=5,
index_batch_size=2048,
generate_batch_size=1024,
return_tree=True
)
retrieval_tree = retrieve_result.tree_hop_graph[0]
retrieval_tree.plot_tree()
print(retrieval_tree.nodes(data=True))
✨ 主要特性
- 处理复杂查询:现实世界中的问题通常需要多跳检索才能获取相关信息,传统检索方法在这方面存在困难,而 TreeHop 能够有效应对。
- 成本效益高:TreeHop 仅包含 2500 万个参数,相比现有查询重写器的数十亿参数,显著降低了计算开销。
- 速度快:与基于迭代 LLM 的方法相比,推理速度快 99%,非常适合对响应速度要求较高的工业应用。
- 性能出色:在控制检索段落数量的情况下,保持高召回率,确保相关性的同时不会使系统负担过重。
📦 安装指南
本项目依赖 Python 3.9 及以上版本,具体依赖请参考 requirements.txt。
📚 详细文档
论文复现
要评估 TreeHop 的多跳检索性能,请运行以下代码。以下以 2WikiMultihop 数据集和三跳下的 recall@5 为例。脚本将打印每一跳的召回率和平均检索段落数,以及按问题类型的统计信息。
⚠️ 重要提示
- 如需更改评估数据集,请将
2wiki
替换为 musique
或 multihop_rag
。
- 修改
n_hop
和 top_n
以更改跳数和顶级检索设置。
- 切换
redundant_pruning
和 layerwise_top_pruning
以复现我们关于停止准则的消融研究。
python evaluation.py \
--dataset_name multihop_rag \
--revision paper-reproduction \
--n_hop 3 \
--top_n 5 \
--redundant_pruning True \
--layerwise_top_pruning True
训练 TreeHop
运行以下代码生成图并训练 TreeHop。有关此脚本的参数,请参考 training.py 中的 parse_args
函数。有关训练嵌入生成,请参考 init_train_vectors.py 中的代码。
python training.py --graph_cache_dir ./train_data/
🔧 技术细节
嵌入数据库
我们采用 BGE-m3 进行嵌入生成,并在此基础上训练 TreeHop 模型进行多跳检索。运行以下两个脚本可生成所有必要的训练和评估嵌入数据库。如果不打算训练 TreeHop,则无需运行这些脚本,因为仓库中已提供所有必要的评估嵌入数据库。
python init_train_vectors.py
python init_multihop_rag.py
📄 许可证
本项目采用 MIT 许可证。有关详细信息,请参阅 LICENSE。
📖 引用
如果您在研究中使用了本项目,请引用以下论文:
@misc{li2025treehopgeneratefilterquery,
title={TreeHop: Generate and Filter Next Query Embeddings Efficiently for Multi-hop Question Answering},
author={Zhonghao Li and Kunpeng Zhang and Jinghuai Ou and Shuliang Liu and Xuming Hu},
year={2025},
eprint={2504.20114},
archivePrefix={arXiv},
primaryClass={cs.IR},
url={https://arxiv.org/abs/2504.20114},
}