语言: 英语
缩略图:
许可证: MIT
标签:
- 问答系统
数据集:
- SQuAD
评估指标:
- SQuAD
小部件示例:
- 文本: "埃菲尔铁塔位于哪里?"
上下文: "埃菲尔铁塔是一座位于法国巴黎战神广场的锻铁格子塔,以工程师古斯塔夫·埃菲尔命名,其公司设计并建造了该塔。"
- 文本: "弗雷德里克·肖邦是谁?"
上下文: "弗雷德里克·弗朗索瓦·肖邦(1810年3月1日-1849年10月17日),原名弗里德里克·弗朗齐歇克·肖邦,是浪漫主义时期波兰作曲家和钢琴演奏家,主要创作独奏钢琴曲。"
基于BERT-base uncased模型在SQuAD v1上的微调
本模型使用nn_pruning Python库创建:**线性层保留了原版27.0%**的权重。
该模型必须配合nn_pruning的optimize_model
函数使用,因为它采用NoNorms替代LayerNorms,而这一特性当前未被Transformers库原生支持。
为加速推理,模型将初始BERT网络中的GeLU激活函数替换为ReLU。这一改动无需特殊处理,Transformers库已支持并通过模型配置中的"hidden_act": "relu"
标识。
模型总体保留43.0%的原始权重(嵌入层占模型重要部分且未参与此次剪枝)。通过简单的线性矩阵缩放,模型在评估时运行速度达到bert-base-uncased的1.96倍。这得益于剪枝方法生成的结构化矩阵——可视化图表可通过悬停查看各矩阵非零/零部分分布。
准确度方面,其F1值为88.33,较bert-base-uncased的88.5仅下降0.17。
精细剪枝细节
本模型基于HuggingFace的bert-base-uncased检查点在SQuAD1.1微调,并蒸馏自bert-large-uncased-whole-word-masking-finetuned-squad模型。该模型不区分大小写(如english和English视为相同)。
块剪枝的副作用是部分注意力头被完全移除:144个注意力头中有55个被剪除(38.2%)。下图展示了剪枝后剩余注意力头在网络中的分布情况。
SQuAD1.1数据集详情
数据集 |
划分 |
样本数 |
SQuAD1.1 |
训练集 |
90.6K |
SQuAD1.1 |
评估集 |
11.1k |
微调环境
CPU: Intel(R) Core(TM) i7-6700K
内存: 64 GiB
GPU: 1块GeForce GTX 3090(24GiB显存)
驱动版本: 455.23.05, CUDA: 11.1
性能结果
PyTorch模型文件大小: 374MB
(原始BERT: 420MB
)
指标 |
本模型值 |
原论文值(Table 2) |
差异 |
EM |
81.31 |
80.8 |
+0.51 |
F1 |
88.33 |
88.5 |
-0.17 |
使用示例
首先安装nn_pruning工具包(包含优化脚本,通过移除空行/列压缩线性层):
pip install nn_pruning
然后可近乎常规地使用transformers库,只需在加载管道后调用optimize_model
:
from transformers import pipeline
from nn_pruning.inference_model_patcher import optimize_model
qa_pipeline = pipeline(
"question-answering",
model="madlag/bert-base-uncased-squadv1-x1.96-f88.3-d27-hybrid-filled-opt-v1",
tokenizer="madlag/bert-base-uncased-squadv1-x1.96-f88.3-d27-hybrid-filled-opt-v1"
)
print("bert-base-uncased参数量: 191.0M")
print(f"当前参数量(仅含注意力头剪枝)={int(qa_pipeline.model.num_parameters() / 1E6)}M")
qa_pipeline.model = optimize_model(qa_pipeline.model, "dense")
print(f"完整优化后参数量={int(qa_pipeline.model.num_parameters() / 1E6)}M")
predictions = qa_pipeline({
'context': "弗雷德里克·弗朗索瓦·肖邦(1810年3月1日-1849年10月17日),原名弗里德里克·弗朗齐歇克·肖邦,是浪漫主义时期波兰作曲家和钢琴演奏家,主要创作独奏钢琴曲。",
'question': "弗雷德里克·肖邦是谁?",
})
print("预测结果", predictions)