🚀 multi-qa_v1-distilbert-cls_dot
SentenceTransformers 是一套模型和框架,可根据给定数据进行训练并生成句子嵌入向量。生成的句子嵌入向量可用于聚类、语义搜索等任务。本模型使用预训练的 distilbert-base-uncased 模型,并通过孪生网络设置和对比学习目标进行训练。使用 StackExchange 中的问答对作为训练数据,使模型在问题/答案嵌入相似度方面表现更稳健。此模型使用 cls 输出而非平均池化作为句子嵌入,并使用点积来计算相似度以实现学习目标。
✨ 主要特性
- SentenceTransformers 框架:借助 SentenceTransformers 框架,能够进行句子嵌入的训练和生成。
- 预训练模型:基于 distilbert-base-uncased 预训练模型进行训练。
- 对比学习:采用孪生网络设置和对比学习目标,使用问答对数据提升模型性能。
- 特定输出方式:使用 cls 输出作为句子嵌入,通过点积计算相似度。
📚 详细文档
模型描述
SentenceTransformers 是一套能从给定数据中训练和生成句子嵌入的模型与框架。生成的句子嵌入可用于聚类、语义搜索等任务。我们使用预训练的 distilbert-base-uncased 模型,通过孪生网络设置和对比学习目标进行训练。以 StackExchange 的问答对作为训练数据,让模型在问题/答案嵌入相似度上更具鲁棒性。此模型使用 cls 输出而非平均池化作为句子嵌入,用点积计算相似度以达成学习目标。
我们在 Hugging Face 组织的 使用 JAX/Flax 进行 NLP 和 CV 的社区周 期间开发了该模型。此模型是项目 使用 10 亿训练对训练史上最佳句子嵌入模型 的一部分。我们借助高效的硬件基础设施(7 个 TPU v3 - 8)以及 Google 的 Flax、JAX 和云团队成员在高效深度学习框架方面的帮助来运行该项目。
预期用途
我们的模型旨在用作搜索引擎的句子编码器。给定输入句子,它会输出一个捕获句子语义信息的向量。该句子向量可用于语义搜索、聚类或句子相似度任务。
使用方法
以下是使用 SentenceTransformers 库获取给定文本特征的示例代码:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot')
text = "Replace me by any question / answer you'd like."
text_embbedding = model.encode(text)
训练过程
预训练
我们使用预训练的 distilbert-base-uncased 模型。有关预训练过程的更多详细信息,请参考该模型的卡片。
微调
我们使用对比目标对模型进行微调。具体而言,我们计算批次中每个可能的句子对的余弦相似度,然后通过与真实对进行比较来应用交叉熵损失。
超参数
我们在 TPU v3 - 8 上训练模型。训练 80k 步,批次大小为 1024(每个 TPU 核心 128)。使用 500 的学习率预热。序列长度限制为 128 个标记。使用 AdamW 优化器,学习率为 2e - 5。完整的训练脚本可在当前仓库中获取。
训练数据
我们使用多个 Stackexchange 问答数据集的拼接来微调模型,同时也使用了 MSMARCO、NQ 等问答数据集。
数据集 |
论文 |
训练元组数 |
Stack Exchange QA - Title & Answer |
- |
4,750,619 |
Stack Exchange |
- |
364,001 |
TriviaqQA |
- |
73,346 |
SQuAD2.0 |
论文 |
87,599 |
Quora Question Pairs |
- |
103,663 |
Eli5 |
论文 |
325,475 |
PAQ |
论文 |
64,371,441 |
WikiAnswers |
论文 |
77,427,422 |
MS MARCO |
论文 |
9,144,553 |
GOOAQ: Open Question Answering with Diverse Answer Types |
论文 |
3,012,496 |
Yahoo Answers Question/Answer |
论文 |
681,164 |
SearchQA |
- |
582,261 |
Natural Questions (NQ) |
论文 |
100,231 |