DiTy/俄语MS-MARCO交叉编码器
这是一个基于DeepPavlov/rubert-base-cased预训练模型并通过MS-MARCO俄语段落排序数据集微调的sentence-transformers模型。该模型可用于俄语信息检索:给定查询时,对所有可能段落(例如通过ElasticSearch检索获得)进行编码,然后按得分降序排序。详见SBERT.net检索与重排序。
使用方式(Sentence-Transformers)
安装sentence-transformers后即可轻松使用:
pip install -U sentence-transformers
使用示例:
from sentence_transformers import CrossEncoder
reranker_model = CrossEncoder('DiTy/cross-encoder-russian-msmarco', max_length=512, device='cuda')
query = ["应该多久去看一次牙医?"]
documents = [
"强制最低就诊频率为每年一次,但专家建议更频繁——每半年一次,最佳是每季度一次。这样能及时发现并解决初期问题。",
"主要原因是牙齿表层釉质变薄,釉质可保护牙齿免受机械、化学和温度影响。釉质下是结构更软的牙本质,布满微小管腔。釉质受损会导致牙本质管暴露,刺激传导至神经末梢引发疼痛。最常见于牙龈附近区域,该处釉质最薄且磨损最快。",
"牙医(或称口腔外科医生)是专门从事牙科医疗的专业人员,专注于牙齿、牙龈和口腔健康的医学分支。",
"尤金叔叔是牙医",
"树莓果实可鲜食,也可冷冻或制作果酱、果冻、果汁及果泥。树莓酒、露酒、浸酒和利口酒具有卓越风味。"
]
predict_result = reranker_model.predict([[query[0], documents[0]]])
print(predict_result)
rank_result = reranker_model.rank(query[0], documents)
print(rank_result)
使用方式(HuggingFace Transformers)
若不使用sentence-transformers,可通过以下方式使用模型:首先将输入传入transformer模型,然后获取模型输出的逻辑值。
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained('DiTy/cross-encoder-russian-msmarco')
tokenizer = AutoTokenizer.from_pretrained('DiTy/cross-encoder-russian-msmarco')
features = tokenizer(["应该多久去看一次牙医?", "应该多久去看一次牙医?"], ["强制最低就诊频率为每年一次,但专家建议更频繁——每半年一次,最佳是每季度一次。这样能及时发现并解决初期问题。", "尤金叔叔是牙医"], padding=True, truncation=True, return_tensors='pt')
model.eval()
with torch.no_grad():
scores = model(**features).logits
print(scores)