库名称: transformers
标签:
- 语法纠错
- 语法检查
语言:
- 英语
评估指标:
- 准确率
基础模型:
- microsoft/deberta-v3-large
任务标签: 标记分类
模型卡片
模型详情
模型描述
该模型是基于microsoft/deberta-v3-large
微调的语法纠错(GEC)系统,旨在检测和修正英语文本中的语法错误。模型专注于常见语法错误,如动词时态、名词变形、形容词用法等。特别适合语言学习者或需要提升语法精确度的应用场景。
- 模型类型: 带序列到序列修正的标记分类
- 支持语言: 英语
- 微调基础模型:
microsoft/deberta-v3-large
使用场景
直接使用
该模型可直接用于英语文本的语法错误检测与修正,适合集成到写作助手、教育软件或校对工具中。
下游任务
可针对特定领域(如学术写作、商务沟通或非正式文本修正)进行微调,确保在特定语境下的语法纠错精度。
不适用场景
本模型不适用于非英语文本、非语法类修正(如风格、语气或逻辑),也无法处理超出基础语法结构的复杂错误。
偏差、风险与限制
模型基于通用英语语料库训练,在非标准方言(如口语)或专业术语场景可能表现不佳。用户在此类语境中应谨慎使用,训练数据的局限性可能导致错误识别或遗漏。
使用建议
虽然模型整体表现良好,但在专业或创意写作等语法规则较灵活的语境中,建议人工复核修正结果。
快速开始
使用以下代码加载模型:
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoTokenizer
from transformers.file_utils import ModelOutput
from transformers.models.deberta_v2.modeling_deberta_v2 import (
DebertaV2Model,
DebertaV2PreTrainedModel,
)
@dataclass
class XGECToROutput(ModelOutput):
"""
`XGECToRForTokenClassification.forward()`的输出类型。
loss (`torch.FloatTensor`, 可选)
logits_correction (`torch.FloatTensor`) : 每个token的修正logits。
logits_detection (`torch.FloatTensor`) : 每个token的检测logits。
hidden_states (`Tuple[torch.FloatTensor]`, 可选)
attentions (`Tuple[torch.FloatTensor]`, 可选)
"""
loss: Optional[torch.FloatTensor] = None
logits_correction: torch.FloatTensor = None
logits_detection: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class XGECToRDebertaV3(DebertaV2PreTrainedModel):
"""
扩展GECToR模型,包含错误检测头和标记分类头。
"""
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.unk_tag_idx = config.label2id.get("@@UNKNOWN@@", None)
self.deberta = DebertaV2Model(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
if self.unk_tag_idx is not None:
self.error_detector = nn.Linear(config.hidden_size, 3)
else:
self.error_detector = nn.Linear(config.hidden_size, 2)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits_correction = self.classifier(sequence_output)
logits_detection = self.error_detector(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits_correction.view(-1, self.num_labels), labels.view(-1))
labels_detection = torch.ones_like(labels)
labels_detection[labels == 0] = 0
labels_detection[labels == -100] = -100
if self.unk_tag_idx is not None:
labels_detection[labels == self.unk_tag_idx] = 2
loss_detection = loss_fct(logits_detection.view(-1, 3), labels_detection.view(-1))
else:
loss_detection = loss_fct(logits_detection.view(-1, 2), labels_detection.view(-1))
loss += loss_detection
if not return_dict:
output = (logits_correction, logits_detection) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return XGECToROutput(
loss=loss,
logits_correction=logits_correction,
logits_detection=logits_detection,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def get_input_embeddings(self):
return self.deberta.get_input_embeddings()
def set_input_embeddings(self, value):
self.deberta.set_input_embeddings(value)
config = AutoConfig.from_pretrained("manred1997/deberta-v3-large-lemon-spell_5k")
tokenizer = AutoTokenizer.from_pretrained("manred1997/deberta-v3-large-lemon-spell_5k")
model = XGECToRDeberta.from_pretrained(
"manred1997/deberta-v3-large-lemon-spell_5k", config=config
)
训练详情
训练数据
模型训练分为三个阶段,各阶段使用特定数据集:
阶段 |
使用数据集 |
说明 |
阶段1 |
PIE语料库(A1部分)900万条乱序句子 |
聚焦A1难度级别的句子 |
阶段2 |
NUCLE/FCE/Lang8/W&I+Locness混合数据集 |
Lang8数据集含947,344句,52.5%存在源/目标句差异 |
阶段3 |
W&I+Locness乱序数据集 |
最终优化数据集 |
评估
测试数据与指标
测试数据
使用标准语法纠错基准测试集W&I+Locness进行评估。
评估指标
主要使用F0.5分数衡量模型识别和修正语法错误的能力。
评估结果
F0.5 = 74.61