标签:
该模型是SPAR论文中Wiki BM25词汇模型(Λ)的查询编码器:
显著短语感知的密集检索:密集检索器能否模仿稀疏检索器?
作者:Xilun Chen, Kushal Lakhotia, Barlas Oğuz, Anchit Gupta, Patrick Lewis, Stan Peshterliev, Yashar Mehdad, Sonal Gupta 和 Wen-tau Yih
Meta AI
相关GitHub仓库地址:https://github.com/facebookresearch/dpr-scale/tree/main/spar
该模型是一个基于BERT-base架构的密集检索器,在维基百科文章上训练,旨在模仿BM25的行为。以下是其他可用模型:
预训练模型 |
语料库 |
教师模型 |
架构 |
查询编码器路径 |
上下文编码器路径 |
Wiki BM25 Λ |
维基百科 |
BM25 |
BERT-base |
facebook/spar-wiki-bm25-lexmodel-query-encoder |
facebook/spar-wiki-bm25-lexmodel-context-encoder |
PAQ BM25 Λ |
PAQ |
BM25 |
BERT-base |
facebook/spar-paq-bm25-lexmodel-query-encoder |
facebook/spar-paq-bm25-lexmodel-context-encoder |
MARCO BM25 Λ |
MS MARCO |
BM25 |
BERT-base |
facebook/spar-marco-bm25-lexmodel-query-encoder |
facebook/spar-marco-bm25-lexmodel-context-encoder |
MARCO UniCOIL Λ |
MS MARCO |
UniCOIL |
BERT-base |
facebook/spar-marco-unicoil-lexmodel-query-encoder |
facebook/spar-marco-unicoil-lexmodel-context-encoder |
单独使用词汇模型(Λ)
该模型应与关联的上下文编码器配合使用,类似于DPR模型。
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
query_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
context_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-context-encoder')
query = "玛丽·居里出生在哪里?"
contexts = [
"玛丽亚·斯克沃多夫斯卡,后来被称为玛丽·居里,出生于1867年11月7日。",
"皮埃尔·居里1859年5月15日出生于巴黎,是来自阿尔萨斯的法国天主教医生欧仁·居里的儿子。"
]
query_input = tokenizer(query, return_tensors='pt')
ctx_input = tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
score1 = query_emb @ ctx_emb[0]
score2 = query_emb @ ctx_emb[1]
结合基础密集检索器使用词汇模型(Λ)(如SPAR中所示)
由于Λ从稀疏教师检索器学习了词汇匹配,它可以与标准密集检索器(如DPR、Contriever)结合使用,构建一个在词汇和语义匹配上都表现出色的密集检索器。
以下示例展示了如何通过拼接DPR和Wiki BM25 Λ的嵌入来构建SPAR-Wiki模型,用于开放域问答:
import torch
from transformers import AutoTokenizer, AutoModel
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
dpr_ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
dpr_ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
dpr_query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-multiset-base")
dpr_query_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-multiset-base")
lexmodel_tokenizer = AutoTokenizer.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
lexmodel_query_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
lexmodel_context_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-context-encoder')
query = "玛丽·居里出生在哪里?"
contexts = [
"玛丽亚·斯克沃多夫斯卡,后来被称为玛丽·居里,出生于1867年11月7日。",
"皮埃尔·居里1859年5月15日出生于巴黎,是来自阿尔萨斯的法国天主教医生欧仁·居里的儿子。"
]
dpr_query_input = dpr_query_tokenizer(query, return_tensors='pt')['input_ids']
dpr_query_emb = dpr_query_encoder(dpr_query_input).pooler_output
dpr_ctx_input = dpr_ctx_tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
dpr_ctx_emb = dpr_ctx_encoder(**dpr_ctx_input).pooler_output
lexmodel_query_input = lexmodel_tokenizer(query, return_tensors='pt')
lexmodel_query_emb = lexmodel_query_encoder(**query_input).last_hidden_state[:, 0, :]
lexmodel_ctx_input = lexmodel_tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
lexmodel_ctx_emb = lexmodel_context_encoder(**ctx_input).last_hidden_state[:, 0, :]
concat_weight = 0.7
spar_query_emb = torch.cat(
[dpr_query_emb, concat_weight * lexmodel_query_emb],
dim=-1,
)
spar_ctx_emb = torch.cat(
[dpr_ctx_emb, lexmodel_ctx_emb],
dim=-1,
)
score1 = spar_query_emb @ spar_ctx_emb[0]
score2 = spar_query_emb @ spar_ctx_emb[1]