数据集:
- spider
- spider-Syn
评估指标:
- 精确匹配
语言:
- 英文
结果:
- 任务:
类型: 文本转SQL
名称: 文本到SQL
数据集:
类型: spider
名称: Spider
拆分: 验证集
指标:
管道标签: 文本到文本生成
标签:
T5大型语言模型适配器:文本到SQL转换
该模型旨在根据自然语言提示生成结构化的SQL查询语句。
简介
在Text2SQL任务中,模型学习如何基于自然语言提出的问题生成SQL查询。然而在某些情况下,生成的SQL查询可能包含未知列名等问题,且未充分考虑特定数据库的表结构。
这正是我们解决方案的创新点——在训练过程中,我们将数据库表结构整合至输入问题中,明确指定可用的数据列和关联关系,从而生成可实际执行的SQL查询。
通过将数据库表结构与问题提示共同输入,模型能够学习表结构与预期输出之间的映射关系。这种设计使模型能够更好地泛化至训练数据中未出现过的数据库结构。
基础模型
本模型基于t5-large-LM-adapt检查点进行微调。
Spider与Spider-Syn数据集
模型使用Spider和Spider-Syn数据集的训练集进行微调。不同于仅使用问题文本的传统方法,我们将数据库表结构信息与问题结合输入,使模型能够针对给定数据库生成查询语句。
输入示例:
问题: 所有法国音乐家的平均、最小和最大年龄是多少?
表结构: "体育场" "场馆编号" 整型 , "位置" 文本 , "名称" 文本 , "容量" 整型 , "最高值" 整型 , "最低值" 整型 ,
"平均值" 整型 , 外键: 主键: "场馆编号" [分隔符] "歌手" "歌手编号" 整型 , "姓名" 文本 , "国籍" 文本 ,
"歌曲名称" 文本 , "歌曲发行年份" 文本 , "年龄" 整型 , "性别" 布尔型 ,
外键: 主键: "歌手编号" [分隔符],
"演唱会" "演出编号" 整型 , "演唱会名称" 文本 , "主题" 文本 , "年份" 文本 , 外键: "场馆编号" 文本 来自 "体育场",
"场馆编号" , 主键: "演出编号" [分隔符] "演唱会歌手关联表",
外键: "演出编号" 整型 来自 "演唱会",
"演出编号" , "歌手编号" 文本 来自 "歌手" "歌手编号" , 主键: "演出编号" "歌手编号"
预期输出:
SELECT avg(年龄), min(年龄), max(年龄) FROM 歌手 WHERE 国籍 = '法国'
执行评估时,我们查询_SQLite_数据库获得结果:
[[34.5, 25, 43]]
数据库表结构格式
模型训练使用的标准化表结构表示:
表名 列1名称 列1类型 列2名称 列2类型 ... 外键: 外键名 外键类型 来自 表名 列名 主键: 列名 [分隔符]
表名2 ...
使用方式
以下是通过🤗 Transformers库在PyTorch中使用本模型的示例:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
模型路径 = 'gaussalgo/T5-LM-Large-text2sql-spider'
模型 = AutoModelForSeq2SeqLM.from_pretrained(模型路径)
分词器 = AutoTokenizer.from_pretrained(模型路径)
问题 = "所有法国音乐家的平均、最小和最大年龄是多少?"
表结构 = """
"体育场" "场馆编号" 整型 , "位置" 文本 , "名称" 文本 , "容量" 整型 , "最高值" 整型 , "最低值" 整型 , "平均值" 整型 , 外键: 主键: "场馆编号" [分隔符] "歌手" "歌手编号" 整型 , "姓名" 文本 , "国籍" 文本 , "歌曲名称" 文本 , "歌曲发行年份" 文本 , "年龄" 整型 , "性别" 布尔型 , 外键: 主键: "歌手编号" [分隔符] "演唱会" "演出编号" 整型 , "演唱会名称" 文本 , "主题" 文本 , "年份" 文本 , 外键: "场馆编号" 文本 来自 "体育场" "场馆编号" , 主键: "演出编号" [分隔符] "演唱会歌手关联表" 外键: "演出编号" 整型 来自 "演唱会" "演出编号" , "歌手编号" 文本 来自 "歌手" "歌手编号" , 主键: "演出编号" "歌手编号"
"""
输入文本 = " ".join(["问题: ",问题, "表结构:", 表结构])
模型输入 = 分词器(输入文本, return_tensors="pt")
输出 = 模型.generate(**模型输入, max_length=512)
输出文本 = 分词器.batch_decode(输出, skip_special_tokens=True)
print("SQL查询:")
print(输出文本)
输出结果:
SQL查询:
SELECT avg(年龄), min(年龄), max(年龄) FROM 歌手 WHERE 国籍 = '法国'
评估
模型在Spider和Spider-syn数据集的开发集上进行评估。开发集包含的数据库与训练集完全无交集,确保模型在训练过程中未接触过评估用的数据库结构。
Spider和Spider-Syn开发集各包含1032个样本。
- Spider开发集准确率: 49.2%
- Spider Syn开发集准确率: 39.5%
训练过程
模型使用Adaptor库0.2.1版本训练,针对Spider和Spider-syn数据集的训练集采用以下参数配置:
训练参数 = 自适应参数(输出目录="训练目录",
学习率=5e-5,
停止策略=停止策略.全目标收敛,
停止耐心值=8,
保存总数限制=8,
是否训练=True,
是否评估=True,
启用bf16=True,
预热步数=1000,
梯度累积步数=8,
日志记录步长=10,
评估步长=200,
保存步长=1000,
训练轮数=10,
评估策略="分步执行")
虽然训练过程较易复现,但我们不计划公开修改后的Spider数据集副本。如需进一步了解,欢迎通过提交PR或发送邮件至stefanik@gaussalgo.com与我们联系。