语言:
- 德语
- 法语
- 意大利语
- 罗曼什语
任务标签: 句子相似度
SwissBERT模型通过自监督的SimCSE(Gao等,EMNLP 2021)方法进行了微调,用于生成句子嵌入。训练数据来自截至2023年的约150万篇瑞士新闻文章(通过Swissdox@LiRI获取)。采用Sentence Transformers方法(Reimers和Gurevych,2019),使用最后一层隐藏状态的平均值(pooler_type=avg)作为句子表示。

模型详情
模型描述
- 开发者: Juri Grosjean
- 模型类型: XMOD
- 支持语言(NLP): 瑞士德语(de_CH)、瑞士法语(fr_CH)、瑞士意大利语(it_CH)、瑞士罗曼什语(rm_CH)
- 许可证: 署名-非商业性使用 4.0 国际 (CC BY-NC 4.0)
- 基础模型: SwissBERT
使用方法
import torch
from transformers import AutoModel, AutoTokenizer
model_name = "jgrosjean-mathesis/sentence-swissbert"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_sentence_embedding(sentence, language):
if "de" in language:
model.set_default_language("de_CH")
if "fr" in language:
model.set_default_language("fr_CH")
if "it" in language:
model.set_default_language("it_CH")
if "rm" in language:
model.set_default_language("rm_CH")
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
with torch.no_grad():
outputs = model(**inputs)
token_embeddings = outputs.last_hidden_state
attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
embedding = sum_embeddings / sum_mask
return embedding
sentence_0 = "8月1日我们庆祝瑞士国庆节。"
sentence_0_embedding = generate_sentence_embedding(sentence_0, language="de")
print(sentence_0_embedding)
输出:
tensor([[ 0.0563, -0.2837, -0.0415, ..., ]])
语义文本相似度
from sklearn.metrics.pairwise import cosine_similarity
sentence_1 = ["火车9点到达苏黎世。"]
sentence_2 = ["列车将于9点抵达洛桑。"]
embedding_1 = generate_sentence_embedding(sentence_1, language="de")
embedding_2 = generate_sentence_embedding(sentence_2, language="fr")
cosine_score = cosine_similarity(embedding_1, embedding_2)
print("句子1与句子2的相似度得分为:", cosine_score)
输出:
句子1与句子2的相似度得分为: [[0.85555995]]
偏差、风险与限制
该模型仅基于新闻文章训练,在其他文本类型上可能表现不佳。其瑞士语境特异性也意味着对非相关文本效果有限。此外,该模型未经机器翻译任务的训练或评估。
训练详情
训练数据
训练数据来自Swissdox@LiRI数据库中截至2023年的德语、法语、意大利语和罗曼什语文档。
训练过程
采用自监督SimCSE方法微调,正样本对为文章正文与标题导语的组合,未使用硬负样本。
训练脚本见GitHub仓库。
超参数设置
- 训练轮次: 1
- 学习率: 1e-5
- 批次大小: 512
- 温度系数: 0.05
评估
测试数据
使用Kew等(2023)编制的20 Minuten数据集,包含带主题标签和摘要的瑞士新闻文章,部分数据通过Google Cloud API自动翻译为法语、意大利语,通过Textshuttle API翻译为罗曼什语。
文档检索评估
计算文档摘要与正文的嵌入向量,通过最大化余弦相似度进行匹配,以准确率(正确匹配比例)衡量性能。评估脚本见GitHub。
文本分类评估
将带主题标签的文章映射到10个类别,按80%/20%划分训练测试集,采用k近邻方法进行分类。评估脚本见GitHub。
注:法语、意大利语和罗曼什语测试使用德语训练数据,以评估跨语言迁移能力。
评估结果
在多数评估任务中,Sentence SwissBERT表现优于最佳多语言Sentence-BERT模型(distiluse-base-multilingual-cased),仅在意大利语文本分类任务中稍逊。
评估任务 |
Swissbert |
|
Sentence Swissbert |
|
Sentence-BERT |
|
|
准确率 |
F1值 |
准确率 |
F1值 |
准确率 |
F1值 |
文档检索(德语) |
87.20% |
-- |
93.40% |
-- |
91.80% |
-- |
文档检索(法语) |
84.97% |
-- |
93.99% |
-- |
93.19% |
-- |
文档检索(意语) |
84.17% |
-- |
92.18% |
-- |
91.58% |
-- |
文档检索(罗语) |
83.17% |
-- |
91.58% |
-- |
73.35% |
-- |
文本分类(德语) |
-- |
77.93% |
-- |
78.49% |
-- |
77.23% |
文本分类(法语) |
-- |
69.62% |
-- |
77.18% |
-- |
76.83% |
文本分类(意语) |
-- |
67.09% |
-- |
76.65% |
-- |
76.90% |
文本分类(罗语) |
-- |
43.79% |
-- |
77.20% |
-- |
65.35% |
基线模型
基线使用原始SwissBERT模型的最后一层隐藏状态均值池化嵌入,以及表现最佳的多语言Sentence-BERT模型distiluse-base-multilingual-cased-v1。