---
pipeline_tag: sentence-similarity
license: apache-2.0
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
- transformers
---
# sentence-transformers/msmarco-distilbert-base-tas-b
This is a port of the DistilBert TAS-B model to the sentence-transformers framework: it maps sentences and paragraphs to a 768-dimensional dense vector space and is optimized for the task of semantic search.
## Usage (Sentence-Transformers)

Using this model is straightforward once you have sentence-transformers installed:

```
pip install -U sentence-transformers
```

Then you can use the model like this:
```python
from sentence_transformers import SentenceTransformer, util

query = "How many people live in London?"
docs = ["Around 9 million people live in London", "London is known for its financial district"]

# Load the model
model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')

# Encode query and documents
query_emb = model.encode(query)
doc_emb = model.encode(docs)

# Compute dot score between the query and all document embeddings
scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

# Combine docs & scores, sort by decreasing score
doc_score_pairs = list(zip(docs, scores))
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

# Output passages & scores
for doc, score in doc_score_pairs:
    print(score, doc)
```
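
For retrieval over a larger document collection, sentence-transformers also provides the `util.semantic_search` helper, which chunks the scoring and returns the top-k hits per query. Below is a minimal sketch under the same setup as above; note that `semantic_search` defaults to cosine similarity, so `util.dot_score` is passed explicitly, matching the dot-product scoring used in the snippet above:

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')

query = "How many people live in London?"
docs = ["Around 9 million people live in London", "London is known for its financial district"]

query_emb = model.encode(query, convert_to_tensor=True)
doc_emb = model.encode(docs, convert_to_tensor=True)

# semantic_search returns one result list per query; each hit is a
# dict with 'corpus_id' and 'score' keys
hits = util.semantic_search(query_emb, doc_emb, top_k=2, score_function=util.dot_score)[0]
for hit in hits:
    print(hit['score'], docs[hit['corpus_id']])
```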
## Usage (HuggingFace Transformers)

Without sentence-transformers, you can use the model as follows: first, pass your input through the transformer model, then apply the correct pooling operation on top of the contextualized word embeddings.
```python
from transformers import AutoTokenizer, AutoModel
import torch

# CLS pooling - take the embedding of the first ([CLS]) token
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

# Encode a string or a list of strings into embeddings
def encode(texts):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    # Perform CLS pooling
    embeddings = cls_pooling(model_output)
    return embeddings

query = "How many people live in London?"
docs = ["Around 9 million people live in London", "London is known for its financial district"]

# Load model and tokenizer from the HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-distilbert-base-tas-b")
model = AutoModel.from_pretrained("sentence-transformers/msmarco-distilbert-base-tas-b")

# Encode query and documents
query_emb = encode(query)
doc_emb = encode(docs)

# Compute dot score between the query and all document embeddings
scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()

# Combine docs & scores, sort by decreasing score
doc_score_pairs = list(zip(docs, scores))
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

# Output passages & scores
for doc, score in doc_score_pairs:
    print(score, doc)
```
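
As a quick sanity check, the two code paths above should produce the same vectors, since the Pooling module in the architecture listed below performs exactly this CLS pooling. A minimal sketch, assuming the `encode` function, `model`, and `tokenizer` from the snippet above are already defined in the session:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')

text = "How many people live in London?"

# sentence-transformers path (returns a numpy array)
st_emb = st_model.encode(text)

# Manual Transformers path from above (returns a torch tensor of shape [1, 768])
hf_emb = encode(text)[0].numpy()

# The two embeddings should agree up to floating-point noise
print(np.allclose(st_emb, hf_emb, atol=1e-5))
```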
## Evaluation Results

For an automated evaluation of this model, see the Sentence Embeddings Benchmark: https://seb.sbert.net
## Full Model Architecture

```
SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: DistilBertModel
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
```
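
The `max_seq_length` of 512 (longer inputs are truncated) and the 768-dimensional output can be verified directly on the loaded model; a small sketch using the public sentence-transformers API:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')

print(model.max_seq_length)                      # 512 - longer inputs are truncated
print(model.get_sentence_embedding_dimension())  # 768
```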
## Citing & Authors

For more details, see: DistilBert TAS-B Model