pipeline_tag: 特征提取
tags:
- 句子转换器
- 特征提取
- 句子相似度
language: 英文
license: apache-2.0
all-mpnet-base-v2 克隆版
这是一个sentence-transformers模型:它能将句子和段落映射到768维的密集向量空间,可用于聚类或语义搜索等任务。
本模型与官方版本唯一区别在于修改了README.md中的pipeline_tag: feature-extraction
字段。
使用方法(Sentence-Transformers)
安装sentence-transformers后即可轻松使用:
pip install -U sentence-transformers
使用示例:
from sentence_transformers import SentenceTransformer
sentences = ["这是一个示例句子", "每个句子都会被转换"]
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
embeddings = model.encode(sentences)
print(embeddings)
使用方法(HuggingFace Transformers)
若不使用sentence-transformers,可按以下方式操作:首先通过transformer模型处理输入,然后对上下文词嵌入执行正确的池化操作。
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
sentences = ['这是一个示例句子', '每个句子都会被转换']
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**encoded_input)
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
print("句子嵌入向量:")
print(sentence_embeddings)
评估结果
自动化评估结果请参见句子嵌入基准测试:https://seb.sbert.net
背景
本项目旨在通过自监督对比学习目标,在超大规模句子级数据集上训练句子嵌入模型。我们基于预训练的microsoft/mpnet-base
模型,在10亿句对数据集上进行微调。采用对比学习目标:给定句对中的一个句子,模型需从随机采样的负例中预测出真实配对句子。
该模型开发于Hugging Face组织的使用JAX/Flax进行NLP与CV的社区周活动期间,属于用10亿训练句对打造最佳句子嵌入模型项目。我们受益于高效的硬件基础设施(7块TPU v3-8)以及Google Flax、JAX和云团队关于高效深度学习框架的指导。
用途说明
本模型设计用于句子和短段落编码。输入文本后,可输出包含语义信息的向量。该句子向量可用于信息检索、聚类或句子相似度任务。
默认情况下,超过384个词片的输入文本会被截断。
训练流程
预训练阶段
使用预训练的microsoft/mpnet-base
模型,预训练细节详见模型卡片。
微调阶段
采用对比目标进行微调。具体而言,我们计算批次内所有可能句对的余弦相似度,然后通过与真实句对比较应用交叉熵损失。
超参数配置
使用TPU v3-8进行训练,共10万步,批次大小为1024(每TPU核心128)。采用500步学习率预热,序列长度限制为128个词片。使用AdamW优化器,学习率2e-5。完整训练脚本见仓库中的train_script.py
。
训练数据
通过合并多个数据集进行微调,总句对数量超过10亿。各数据集采样权重详见data_config.json
文件。
数据集 |
论文 |
训练句对数 |
Reddit评论(2015-2018) |
论文 |
726,484,430 |
S2ORC引文对(摘要) |
论文 |
116,288,806 |
WikiAnswers重复问题对 |
论文 |
77,427,422 |
PAQ(问答对) |
论文 |
64,371,441 |
S2ORC引文对(标题) |
论文 |
52,603,982 |
S2ORC(标题,摘要) |
论文 |
41,769,185 |
Stack Exchange(标题,正文) |
- |
25,316,456 |
Stack Exchange(标题+正文,答案) |
- |
21,396,559 |
Stack Exchange(标题,答案) |
- |
21,396,559 |
MS MARCO三元组 |
论文 |
9,144,553 |
GOOAQ: 开放域多样化答案问答 |
论文 |
3,012,496 |
Yahoo Answers(标题,答案) |
论文 |
1,198,260 |
代码搜索 |
- |
1,151,414 |
COCO图像描述 |
论文 |
828,395 |
SPECTER引文三元组 |
论文 |
684,100 |
Yahoo Answers(问题,答案) |
论文 |
681,164 |
Yahoo Answers(标题,问题) |
论文 |
659,896 |
SearchQA |
论文 |
582,261 |
Eli5 |
论文 |
325,475 |
Flickr 30k |
论文 |
317,695 |
Stack Exchange重复问题(标题) |
|
304,525 |
AllNLI(SNLI与MultiNLI) |
SNLI论文, MultiNLI论文 |
277,230 |
Stack Exchange重复问题(正文) |
|
250,519 |
Stack Exchange重复问题(标题+正文) |
|
250,460 |
句子压缩 |
论文 |
180,000 |
Wikihow |
论文 |
128,542 |
Altlex |
论文 |
112,696 |
Quora问题三元组 |
- |
103,663 |
简单维基百科 |
论文 |
102,225 |
自然问题(NQ) |
论文 |
100,231 |
SQuAD2.0 |
论文 |
87,599 |
TriviaQA |
- |
73,346 |
总计 |
|
1,170,060,424 |