许可协议: cc-by-nc-sa-4.0
语言: 英语
标签:
- SPLADE
- 对话式搜索
- 多轮检索
- 查询扩展
- 文档扩展
- 段落检索
- 知识蒸馏
管道标签: 填充掩码
DiSCo:面向高效稀疏检索的对话式搜索大模型知识蒸馏
本模型是基于原始SPLADE++ (CoCondenser-EnsembleDistil)模型改进的对话式搜索版本。它保留了原始文档编码器,并在QReCC数据集上微调了查询编码器,该数据集专为多轮对话式搜索设计。
训练过程采用多教师蒸馏策略:人工与Mistral改写版本,使模型能更好地捕捉对话查询的语义。详见原论文:
- DiSCo SPLADE - SIGIR 2025全文:https://arxiv.org/abs/2410.14609
注意: 此为查询编码器。实际使用时需配合未修改的原始SPLADE++文档编码器。SPLADE支持非对称架构:查询与文档可使用不同表征模型。
使用说明
完整用法请参考DiSCo GitHub仓库[github]。
以下是对话编码的示例脚本:
输入格式为扁平化的对话历史序列:
当前问题 [SEP] 上轮回答 [SEP] 上轮问题 [SEP] ... [SEP] 初始回答 [SEP] 初始问题
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import torch
model = AutoModelForMaskedLM.from_pretrained("slupart/splade-disco-human-mistral")
tokenizer = AutoTokenizer.from_pretrained("slupart/splade-disco-human-mistral")
model.eval()
conv = [
("今天天气怎么样?", "晴天。"),
("需要涂防晒霜吗?", "需要,紫外线指数很高。"),
("要戴太阳镜吗?", "一定要戴。"),
("哪里能买到太阳镜?", "可以去附近的眼镜店。"),
("大概多少钱?", None)
]
parts = [conv[-1][0]] + [x for q, a in reversed(conv[:-1]) for x in (a, q) if x]
text = " [SEP] ".join(parts)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
sparse = F.relu(logits).max(1).values.squeeze(0)
scores = [(tokenizer.convert_ids_to_tokens([i.item()])[0], sparse[i].item())
for i in torch.nonzero(sparse).squeeze(1)]
for token, score in sorted(scores, key=lambda x: -x[1]):
print(f"词元: {token:15} | 得分: {score:.4f}")
引用
若使用本模型,请引用我们的工作:
@article{lupart2024disco,
title={DiSCo与LLMs的邂逅:对话式搜索中稀疏检索与上下文蒸馏的统一方法},
author={Lupart, Simon and Aliannejadi, Mohammad and Kanoulas, Evangelos},
journal={arXiv预印本 arXiv:2410.14609},
year={2024}
}