🚀 基于 RoBERTa 微调的医学问诊意图识别模型
本项目是中科围绕心理健康大模型研发的对话导诊系统中的意图识别任务,能对用户输入的 query
文本进行意图识别,判断是【问诊】还是【闲聊】,为医学对话交互提供精准支持。
🚀 快速开始
单样本推理示例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = '这孩子目前28岁,情绪不好时经常无征兆吐血,呼吸系统和消化系统做过多次检查,没有检查出结果,最近三天连续早晨出现吐血现象'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
终端结果
问诊
批次数据推理示例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃痛,连续拉肚子好几天了,有时候半夜还呕吐',
'腿上的毛怎样去掉,不用任何药学和医学器械',
'你好,感冒咳嗽用什么药?',
'你觉得今天天气如何?我感觉咱可以去露营了!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
终端结果
["问诊", "闲聊", "问诊", "闲聊"]
✨ 主要特性
- 精准意图识别:能准确判别用户输入文本是【问诊】还是【闲聊】意图。
- 数据融合构建:融合开源与内部垂域医学对话数据集,确保数据多样性。
- 微调预训练模型:基于
transformers
库微调 chinese - roberta - wwm - ext
模型,提升性能。
📦 安装指南
在 Featurize 在线平台实例 上,需手动安装以下库:
pip install transformers datasets evaluate accelerate
平台环境信息:
- CPU:6核 E5 - 2680 V4
- GPU:RTX3060,12.6GB显存
- 预装镜像:Ubuntu 20.04,Python 3.9/3.10,PyTorch 2.0.1,TensorFlow 2.13.0,Docker 20.10.10, CUDA 尽量维持在最新版本
📚 详细文档
项目简介
- 项目来源:中科(安徽)G60智慧健康创新研究院(以下简称 “中科”)围绕心理健康大模型研发的对话导诊系统,本项目为其中的意图识别任务。
- 模型用途:将用户输入对话系统中的
query
文本进行意图识别,判别其意向是【问诊】or【闲聊】。
数据描述
- 数据来源:由 Hugging Face 的开源对话数据集,以及中科内部的垂域医学对话数据集经过清洗和预处理融合构建而成。
- 数据划分:共计 6000 条样本,其中,训练集 4800 条,测试集1200 条,并在数据构建过程中确保了正负样例的平衡。
- 数据样例:
[
{
"query": "最近热门的5部电影叫什么名字",
"label": "nonmed"
},
{
"query": "关节疼痛,足痛可能是什么原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,恶心与呕吐,严重影响学习工作",
"label": "med"
}
]
训练方式
基于 Hugging Face 的 transformers
库对哈工大讯飞联合实验室 (HFL) 发布的 [chinese - roberta - wwm - ext](https://github.com/ymcui/Chinese - BERT - wwm) 中文预训练模型进行微调。
训练参数、效果与局限性
训练参数
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
微调效果
数据集 |
准确率 |
F1分数 |
测试集 |
0.99 |
0.98 |
局限性
整体而言,微调后模型对于医学问诊的意图识别效果不错;但碍于本次用于模型训练的数据量终究有限且样本多样性欠佳,故在某些情况下的效果可能存在偏差。
🔧 技术细节
- 模型选择:选择
chinese - roberta - wwm - ext
预训练模型,因其在中文任务上有较好的表现,通过微调可适配医学问诊意图识别任务。
- 数据处理:融合开源与内部数据集,清洗和预处理确保数据质量,划分训练集和测试集保证模型泛化能力。
- 训练优化:使用
transformers
库进行微调,设置合适的训练参数,如学习率、批次大小等,提升模型性能。
📄 许可证
本项目采用 apache - 2.0
许可证。
📋 其他信息
属性 |
详情 |
模型类型 |
基于 RoBERTa 微调的医学问诊意图识别模型 |
训练数据 |
由 Hugging Face 的开源对话数据集和中科内部垂域医学对话数据集融合构建,共 6000 条样本 |
评估指标 |
混淆矩阵、准确率、F1分数 |
应用领域 |
医学 |