---
datasets:
- sentence-transformers/embedding-training-data
- flax-sentence-embeddings/stackexchange_xml
- snli
- eli5
- search_qa
- multi_nli
- wikihow
- natural_questions
- trivia_qa
- ms_marco
- gooaq
- yahoo_answers_topics
language:
- en
inference: false
pipeline_tag: sentence-similarity
task_categories:
- sentence-similarity
- feature-extraction
- text-retrieval
tags:
- information retrieval
- ir
- document retrieval
- passage retrieval
- beir
- benchmark
- sts
- semantic search
- sentence-transformers
- feature-extraction
- sentence-similarity
- transformers
---
# bert-base-1024-biencoder-64M-pairs

A long-context bi-encoder based on MosaicML's pretrained 1024-sequence-length BERT. The model maps sentences and paragraphs to a 768-dimensional dense vector space and can be used for tasks like clustering or semantic search.
## Usage

### Download the model and related scripts

```bash
git clone https://huggingface.co/shreyansh26/bert-base-1024-biencoder-64M-pairs
```

### Inference
```python
import torch
from torch import nn
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline, AutoModel
from mosaic_bert import BertModel


class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, model, tokenizer, normalize=True):
        super(AutoModelForSentenceEmbedding, self).__init__()

        self.model = model.to("cuda")
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        # model_output[0] contains the token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Load the tokenizer first so it can be passed to the embedding wrapper
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained("<path-to-model>", trust_remote_code=True).to("cuda")
model = AutoModelForSentenceEmbedding(model, tokenizer)

sentences = ["This is an example sentence", "Each sentence is converted"]

encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=1024, return_tensors='pt').to("cuda")
embeddings = model(**encoded_input)

print(embeddings)
print(embeddings.shape)
```
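
Because the wrapper L2-normalizes its output (`normalize=True`), the cosine similarity between two sentences reduces to a dot product of their embeddings. The following is a minimal sketch, continuing from the variables defined above, of how the embeddings can be used for semantic similarity:

```python
# Cosine similarity between the two example sentences.
# With normalized embeddings, the dot product equals the cosine similarity.
similarity = embeddings[0] @ embeddings[1]
print(similarity.item())

# Equivalent result without relying on normalization:
cos_sim = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
print(cos_sim.item())
```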
## Other details

### Training

The model was trained on 64M randomly sampled sentence/paragraph pairs drawn from the same training set used for the Sentence Transformers models. Details of the training set are available here.

The training (including hyperparameters), inference, and data loading scripts can all be found in this Github repository.

### Evaluation

We ran the model on several retrieval benchmarks (CQADupstackEnglishRetrieval, DBPedia, MSMARCO, QuoraRetrieval); the results are available here.
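The benchmark names above correspond to tasks available in the MTEB library. As an illustration only (not the evaluation script used for the reported results, which lives in the linked GitHub repository), the sketch below shows how the bi-encoder could be plugged into MTEB; the `BiEncoderForMTEB` wrapper, the batch size, and the choice of task are assumptions made for this example.

```python
# Hypothetical sketch: evaluating the bi-encoder on one of the listed MTEB
# retrieval tasks. MTEB only needs an object exposing an `encode` method
# that maps a list of strings to a 2D array of embeddings.
import numpy as np
import torch
from mteb import MTEB


class BiEncoderForMTEB:
    def __init__(self, model, tokenizer, batch_size=32):
        self.model = model          # AutoModelForSentenceEmbedding from the usage example
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def encode(self, sentences, **kwargs):
        all_embeddings = []
        for i in range(0, len(sentences), self.batch_size):
            batch = sentences[i:i + self.batch_size]
            encoded = self.tokenizer(batch, padding=True, truncation=True,
                                     max_length=1024, return_tensors='pt').to("cuda")
            with torch.no_grad():
                emb = self.model(**encoded)
            all_embeddings.append(emb.cpu().numpy())
        return np.concatenate(all_embeddings, axis=0)


evaluation = MTEB(tasks=["QuoraRetrieval"])
evaluation.run(BiEncoderForMTEB(model, tokenizer), output_folder="results")
```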