语言:
- 日语
库名称: sentence-transformers
标签:
- sentence-transformers
- 句子相似度
- 特征提取
指标:
小部件: []
管道标签: 句子相似度
许可证: apache-2.0
数据集:
- hpprc/emb
- hpprc/mqa-ja
- google-research-datasets/paws-x
RoSEtta
RoSEtta(基于RoFormer的句子编码器通过蒸馏)是一款通用日语文本嵌入模型,擅长检索任务。它支持最大1024的序列长度,可处理长句输入。该模型可在CPU上运行,专为测量句子间语义相似度设计,也可作为基于查询的段落检索系统使用。
核心特性:
- 采用RoPE(旋转位置编码)
- 最大1024标记的序列长度
- 从大型句子嵌入模型蒸馏而来
- 专为检索任务优化
推理时需添加"query: "或"passage: "前缀,具体用法请参阅使用说明。
模型说明
本模型基于RoFormer架构。在使用MLM损失进行预训练后,进行了弱监督学习。此外,通过多个大型嵌入模型蒸馏和多阶段对比学习(类似GLuCoSE v2)进行了进一步训练。
- 最大序列长度: 1024标记
- 输出维度: 768标记
- 相似度函数: 余弦相似度
使用方法
直接使用(Sentence Transformers)
可通过以下代码使用SentenceTransformer进行推理:
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
model = SentenceTransformer("pkshatech/RoSEtta-base-ja",trust_remote_code=True)
sentences = [
'query: PKSHA是怎样的公司?',
'passage: 我们将研发的算法引入众多企业的软件运营中。',
'query: 日本最高山峰是?',
'passage: 富士山(海拔3776.12米)是日本最高峰(剑峰),其优美轮廓作为日本象征在国际上广为人知。',
]
embeddings = model.encode(sentences,convert_to_tensor=True)
print(embeddings.shape)
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
print(similarities)
直接使用(Transformers)
可通过以下代码使用Transformers进行推理:
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
def mean_pooling(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
emb = last_hidden_states * attention_mask.unsqueeze(-1)
emb = emb.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)
return emb
tokenizer = AutoTokenizer.from_pretrained("pkshatech/RoSEtta-base-ja")
model = AutoModel.from_pretrained("pkshatech/RoSEtta-base-ja", trust_remote_code=True)
训练细节
RoSEtta的微调通过以下步骤完成:
阶段1:预训练
阶段2:弱监督学习
阶段3:集成蒸馏
阶段4:对比学习
阶段5:检索专用对比学习
基准测试
检索性能
在MIRACL-ja等数据集评估:
模型 |
参数量 |
MIRACL召回率@5 |
JQaRA nDCG@10 |
JaCWIR MAP@10 |
MLDR nDCG@10 |
RoSEtta |
0.2B |
79.3 |
57.7 |
83.8 |
32.3 |
JMTEB评估
综合评估结果:
模型 |
参数量 |
平均分 |
检索 |
STS |
分类 |
重排序 |
聚类 |
句子对分类 |
RoSEtta |
0.2B |
72.45 |
73.21 |
81.39 |
72.41 |
92.69 |
53.23 |
61.74 |
作者
矢野千尋、後町萌、立花英之、竹川大翔、渡辺洋太郎
许可
本模型基于Apache 2.0许可证发布。