标签:
- 由Keras回调生成
- dpr
许可证: apache-2.0
模型索引:
- 名称: dpr-question_encoder_bert_uncased_L-2_H-128_A-2
结果: []
dpr-question_encoder_bert_uncased_L-2_H-128_A-2
该模型(google/bert_uncased_L-2_H-128_A-2)基于训练数据data.retriever.nq-adv-hn-train(facebookresearch/DPR)从头开始训练。在评估集上取得了以下结果:
评估数据
评估数据集:来自官方DPR GitHub的facebook-dpr-dev-dataset
模型名称 |
数据名称 |
查询数量 |
段落数量 |
R@10 |
R@20 |
R@50 |
R@100 |
R@100 |
nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2(我们的) |
nq-dev数据集 |
6445 |
199795 |
60.53% |
68.28% |
76.07% |
80.98% |
91.45% |
nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2(我们的) |
nq-dev数据集 |
6445 |
199795 |
65.43% |
71.99% |
79.03% |
83.24% |
92.11% |
*facebook/dpr-ctx_encoder-single-nq-base(hf/fb) |
nq-dev数据集 |
6445 |
199795 |
40.94% |
49.27% |
59.05% |
66.00% |
82.00% |
评估数据集:UKPLab/beir测试数据,但我们仅使用了前20万段落。
模型名称 |
数据名称 |
查询数量 |
段落数量 |
R@10 |
R@20 |
R@50 |
R@100 |
R@100 |
nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2(我们的) |
nq-test数据集 |
3452 |
200001 |
49.68% |
59.06% |
69.40% |
75.75% |
89.28% |
nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2(我们的) |
nq-test数据集 |
3452 |
200001 |
51.62% |
61.09% |
70.10% |
76.07% |
88.70% |
*facebook/dpr-ctx_encoder-single-nq-base(hf/fb) |
nq-test数据集 |
3452 |
200001 |
32.93% |
43.74% |
56.95% |
66.30% |
83.92% |
注:*表示我们在相同的评估数据集上进行了评估。
使用方式(HuggingFace Transformers)
passage_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2")
query_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-12_H-128_A-2")
p_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2")
q_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-12_H-128_A-2")
def get_title_text_combined(passage_dicts):
res = []
for p in passage_dicts:
res.append(tuple((p['title'], p['text'])))
return res
processed_passages = get_title_text_combined(passage_dicts)
def extracted_passage_embeddings(processed_passages, model_config):
passage_inputs = tokenizer.batch_encode_plus(
processed_passages,
add_special_tokens=True,
truncation=True,
padding="max_length",
max_length=model_config.passage_max_seq_len,
return_token_type_ids=True
)
passage_embeddings = passage_encoder.predict([np.array(passage_inputs['input_ids']),
np.array(passage_inputs['attention_mask']),
np.array(passage_inputs['token_type_ids'])],
batch_size=512,
verbose=1)
return passage_embeddings
passage_embeddings = extracted_passage_embeddings(processed_passages, model_config)
def extracted_query_embeddings(queries, model_config):
query_inputs = tokenizer.batch_encode_plus(
queries,
add_special_tokens=True,
truncation=True,
padding="max_length",
max_length=model_config.query_max_seq_len,
return_token_type_ids=True
)
query_embeddings = query_encoder.predict([np.array(query_inputs['input_ids']),
np.array(query_inputs['attention_mask']),
np.array(query_inputs['token_type_ids'])],
batch_size=512,
verbose=1)
return query_embeddings
query_embeddings = extracted_query_embeddings(queries, model_config)
训练超参数
训练过程中使用了以下超参数:
框架版本
- Transformers 4.15.0
- TensorFlow 2.7.0
- Tokenizers 0.10.3