🚀 Cross-Encoder for SQuAD (QNLI)
This model was trained with the Cross-Encoder class from SentenceTransformers and is well suited for text-ranking tasks.
🚀 Quick Start
This model is built on the distilbert/distilroberta-base base model using the sentence-transformers library and is intended for text-ranking tasks.
✨ Key Features
- Trained with the Cross-Encoder class from the SentenceTransformers library; performs strongly on text-ranking tasks.
- Trained on the GLUE QNLI dataset, which converts the SQuAD dataset into a natural language inference (NLI) task.
📦 Installation
The original documentation does not list installation steps; the sentence-transformers library can be installed with:

```bash
pip install sentence-transformers
```
💻 Usage Examples
Basic Usage

```python
from sentence_transformers import CrossEncoder

# Load the pre-trained QNLI cross-encoder
model = CrossEncoder('cross-encoder/qnli-distilroberta-base')

# Each input is a (query, paragraph) pair; the model scores how well
# the paragraph answers the query
scores = model.predict([('Query1', 'Paragraph1'), ('Query2', 'Paragraph2')])

# Concrete example
scores = model.predict([
    ('How many people live in Berlin?', 'Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.'),
    ('What is the size of New York?', 'New York City is famous for the Metropolitan Museum of Art.'),
])
```
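For retrieval-style use, the returned scores can rank candidate paragraphs for a query. A minimal, model-free sketch of that ranking step (the score values below are illustrative placeholders, not actual model output):

```python
# Illustrative relevance scores, as model.predict might return for one query
# against several candidate paragraphs (values are made up for this sketch)
scores = [0.91, 0.07, 0.55]
paragraphs = [
    "Berlin had a population of 3,520,031 registered inhabitants.",
    "New York City is famous for the Metropolitan Museum of Art.",
    "Berlin covers an area of 891.82 square kilometers.",
]

# Sort paragraphs by descending score to get a ranked result list
ranked = [p for _, p in sorted(zip(scores, paragraphs),
                               key=lambda t: t[0], reverse=True)]
print(ranked[0])  # the highest-scoring paragraph comes first
```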
Advanced Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/qnli-distilroberta-base')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/qnli-distilroberta-base')

# Tokenize the (query, paragraph) pairs as a single batch
features = tokenizer(
    ['How many people live in Berlin?', 'What is the size of New York?'],
    ['Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'],
    padding=True, truncation=True, return_tensors="pt",
)

model.eval()
with torch.no_grad():
    # Apply a sigmoid to map the raw logits to (0, 1) relevance scores
    scores = torch.sigmoid(model(**features).logits)
print(scores)
```
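The sigmoid above maps each raw logit to a probability-like score in (0, 1). The same transformation in pure Python, for intuition (the logit values are illustrative, not actual model output):

```python
import math

def sigmoid(x: float) -> float:
    # Logistic function: 1 / (1 + e^(-x)), monotone, with range (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

# Illustrative logits: a positive logit yields a score above 0.5,
# a negative logit a score below 0.5
logits = [2.0, -1.5]
probs = [sigmoid(v) for v in logits]
```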
📚 Documentation
📄 License
This project is licensed under the Apache-2.0 license.
| Attribute | Details |
| --- | --- |
| Base model | distilbert/distilroberta-base |
| Model type | Cross-Encoder for SQuAD (QNLI) |
| Training data | GLUE QNLI dataset, which converts the SQuAD dataset into an NLI task |
| Library name | sentence-transformers |
| Tags | transformers |