license: apache-2.0
datasets:
- common_language
language:
- 阿拉伯语
- 巴斯克语
- 布列塔尼语
- 加泰罗尼亚语
- 中文
- 楚瓦什语
- 捷克语
- 荷兰语
- 英语
- 世界语
- 爱沙尼亚语
- 法语
- 格鲁吉亚语
- 德语
- 希腊语
- 印尼语
- 国际语
- 意大利语
- 日语
- 卢旺达语
- 吉尔吉斯语
- 拉脱维亚语
- 马耳他语
- 蒙古语
- 波斯语
- 波兰语
- 葡萄牙语
- 罗马尼亚语
- 罗曼什语
- 俄语
- 斯洛文尼亚语
- 西班牙语
- 瑞典语
- 泰米尔语
- 鞑靼语
- 土耳其语
- 乌克兰语
- 威尔士语
metrics:
- 准确率
- 精确率
- 召回率
- F1值
tags:
- 语言检测
- 弗里斯兰语
- 迪维希语
- 哈卡钦语
- 卡拜尔语
- 萨哈语
概述
该模型支持检测45种语言,基于multilingual-e5-base模型在common-language数据集上微调而成。
整体准确率达98.37%,详细评估结果如下所示。
下载模型
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained('Mike0307/multilingual-e5-language-detection')
model = AutoModelForSequenceClassification.from_pretrained('Mike0307/multilingual-e5-language-detection', num_labels=45)
语言检测示例
import torch
languages = [
"阿拉伯语", "巴斯克语", "布列塔尼语", "加泰罗尼亚语", "中文_中国大陆", "中文_中国香港",
"中文_中国台湾", "楚瓦什语", "捷克语", "迪维希语", "荷兰语", "英语",
"世界语", "爱沙尼亚语", "法语", "弗里斯兰语", "格鲁吉亚语", "德语", "希腊语",
"哈卡钦语", "印尼语", "国际语", "意大利语", "日语", "卡拜尔语",
"卢旺达语", "吉尔吉斯语", "拉脱维亚语", "马耳他语", "蒙古语", "波斯语", "波兰语",
"葡萄牙语", "罗马尼亚语", "罗曼什语", "俄语", "萨哈语", "斯洛文尼亚语",
"西班牙语", "瑞典语", "泰米尔语", "鞑靼语", "土耳其语", "乌克兰语", "威尔士语"
]
def predict(text, model, tokenizer, device = torch.device('cpu')):
model.to(device)
model.eval()
tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors="pt")
input_ids = tokenized['input_ids']
attention_mask = tokenized['attention_mask']
with torch.no_grad():
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=1)
return probabilities
def get_topk(probabilities, languages, k=3):
topk_prob, topk_indices = torch.topk(probabilities, k)
topk_prob = topk_prob.cpu().numpy()[0].tolist()
topk_indices = topk_indices.cpu().numpy()[0].tolist()
topk_labels = [languages[index] for index in topk_indices]
return topk_prob, topk_labels
text = "你的测试句子"
probabilities = predict(text, model, tokenizer)
topk_prob, topk_labels = get_topk(probabilities, languages)
print(topk_prob, topk_labels)
评估结果
测试数据集采用common_language的测试集。
序号 |
语言 |
精确率 |
召回率 |
F1值 |
样本数 |
0 |
阿拉伯语 |
1.00 |
1.00 |
1.00 |
151 |
1 |
巴斯克语 |
0.99 |
1.00 |
1.00 |
111 |
2 |
布列塔尼语 |
1.00 |
0.90 |
0.95 |
252 |
3 |
加泰罗尼亚语 |
0.96 |
0.99 |
0.97 |
96 |
4 |
中文_中国大陆 |
0.98 |
1.00 |
0.99 |
100 |
5 |
中文_中国香港 |
0.97 |
0.87 |
0.92 |
115 |
6 |
中文_中国台湾 |
0.92 |
0.98 |
0.95 |
170 |
7 |
楚瓦什语 |
0.98 |
1.00 |
0.99 |
137 |
8 |
捷克语 |
0.98 |
1.00 |
0.99 |
128 |
9 |
迪维希语 |
1.00 |
1.00 |
1.00 |
111 |
10 |
荷兰语 |
0.99 |
1.00 |
0.99 |
144 |
11 |
英语 |
0.96 |
1.00 |
0.98 |
98 |
12 |
世界语 |
0.98 |
0.98 |
0.98 |
107 |
13 |
爱沙尼亚语 |
1.00 |
0.99 |
0.99 |
93 |
14 |
法语 |
0.95 |
1.00 |
0.98 |
106 |
15 |
弗里斯兰语 |
1.00 |
0.98 |
0.99 |
117 |
16 |
格鲁吉亚语 |
1.00 |
1.00 |
1.00 |
110 |
17 |
德语 |
1.00 |
1.00 |
1.00 |
101 |
18 |
希腊语 |
1.00 |
1.00 |
1.00 |
153 |
19 |
哈卡钦语 |
0.99 |
1.00 |
0.99 |
202 |
20 |
印尼语 |
0.99 |
0.99 |
0.99 |
150 |
21 |
国际语 |
0.96 |
0.97 |
0.96 |
182 |
22 |
意大利语 |
0.99 |
0.94 |
0.96 |
100 |
23 |
日语 |
1.00 |
1.00 |
1.00 |
144 |
24 |
卡拜尔语 |
1.00 |
0.96 |
0.98 |
156 |
25 |
卢旺达语 |
0.97 |
1.00 |
0.99 |
103 |
26 |
吉尔吉斯语 |
0.98 |
1.00 |
0.99 |
129 |
27 |
拉脱维亚语 |
0.98 |
0.98 |
0.98 |
171 |
28 |
马耳他语 |
0.99 |
0.98 |
0.98 |
152 |
29 |
蒙古语 |
1.00 |
1.00 |
1.00 |
112 |
30 |
波斯语 |
1.00 |
1.00 |
1.00 |
123 |
31 |
波兰语 |
0.91 |
0.99 |
0.95 |
128 |
32 |
葡萄牙语 |
0.94 |
0.99 |
0.96 |
124 |
33 |
罗马尼亚语 |
1.00 |
1.00 |
1.00 |
152 |
34 |
罗曼什语 |
0.99 |
0.95 |
0.97 |
106 |
35 |
俄语 |
0.99 |
0.99 |
0.99 |
100 |
36 |
萨哈语 |
0.99 |
1.00 |
1.00 |
105 |
37 |
斯洛文尼亚语 |
0.99 |
1.00 |
1.00 |
166 |
38 |
西班牙语 |
0.96 |
0.95 |
0.95 |
94 |
39 |
瑞典语 |
0.99 |
1.00 |
0.99 |
190 |
40 |
泰米尔语 |
1.00 |
1.00 |
1.00 |
135 |
41 |
鞑靼语 |
1.00 |
0.96 |
0.98 |
173 |
42 |
土耳其语 |
1.00 |
1.00 |
1.00 |
137 |
43 |
乌克兰语 |
0.99 |
1.00 |
1.00 |
126 |
44 |
威尔士语 |
0.98 |
1.00 |
0.99 |
103 |
|
|
|
|
|
|
|
宏平均 |
0.98 |
0.99 |
0.98 |
5963 |
|
加权平均 |
0.98 |
0.98 |
0.98 |
5963 |
|
|
|
|
|
|
|
总体准确率 |
|
|
0.9837 |
5963 |