语言:
- 荷兰语
许可证: cc-by-nc-4.0
标签:
- 对齐手册
- trl
- dpo
- geitje
- 对话式
基础模型: BramVanroy/GEITje-7B-ultra-sft
数据集:
- BramVanroy/ultra_feedback_dutch
管道标签: 文本生成
推理: false
模型索引:
- 名称: BramVanroy/GEITje-7B-ultra
结果: []
GEITje 7B ultra
一个通过AI反馈对齐的荷兰语对话模型。
该模型是基于BramVanroy/GEITje-7B-ultra-sft的微调版本,使用了一个约5600万token的合成DPO数据集,该数据集由gpt-4-turbo和Rijgersberg/GEITje-7B-chat为荷兰语生成。
[!提示]
🚀 寻找快速的GGUF版本?你可以在这里找到它,并了解如何使用ollama
,点击此处。🚀
引用
如果你使用GEITje 7B Ultra(SFT)或其任何衍生或量化版本,请引用以下论文:
@misc{vanroy2024geitje7bultraconversational,
title={GEITje 7B Ultra: 荷兰语对话模型},
author={Bram Vanroy},
year={2024},
eprint={2412.04092},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.04092},
}
模型描述
这是一个基于Mistral的荷兰语指令/聊天模型,通过DPO与AI反馈对齐。它是SFT训练的BramVanroy/GEITje-7B-ultra-sft的DPO延续,而后者又基于Rijgersberg/GEITje-7B,该模型基于Mistral 7B并进一步在荷兰语数据上进行了预训练。在(相当简单的)基准测试中,它平均表现优于所有原始GEITje模型(但差距不大)。然而,请注意,这些基准测试应持保留态度(参见该页面基准测试下方的免责声明)。最佳评估方式是亲自尝试这些模型并自行判断。
使用方法
一次性使用:
from transformers import pipeline, Conversation
chatbot = pipeline("conversational", model="BramVanroy/GEITje-7B-ultra", model_kwargs={"load_in_8bit": True}, device_map="auto")
start_messages = [
{"role": "system", "content": "你是一个名叫Bert的有趣聊天机器人,经常讲笑话。"},
{"role": "user", "content": "你好,我是Bram。今晚我想看电影,你有什么推荐吗?"}
]
conversation = Conversation(start_messages)
conversation = chatbot(conversation)
response = conversation.messages[-1]["content"]
print(response)
交互式对话:
from transformers import pipeline, Conversation
chatbot = pipeline("conversational", model="BramVanroy/GEITje-7B-ultra", model_kwargs={"load_in_8bit": True, "attn_implementation": "flash_attention_2"}, device_map="auto")
while (system_message := input("系统消息(输入'q'退出):")) != "q":
start_messages = [
{"role": "system", "content": system_message},
]
conversation = Conversation(start_messages)
while (user_input := input("用户(输入'r'重置):")) != "r":
conversation.add_user_input(user_input)
conversation = chatbot(conversation)
response = conversation.messages[-1]["content"]
print("助手:", response)
预期用途与限制
尽管该模型已与具有强大内容过滤器的gpt-4-turbo输出对齐,但模型仍可能生成错误、误导甚至可能冒犯的内容。使用风险自负。
由于该模型是在使用OpenAI/Azure服务创建的合成数据上训练的,因此该模型不能用于商业用途。
训练与评估数据
训练数据包括基于UltraFeedback binarized的合成数据集,由gpt-4-turbo和geitje-chat创建。从原始数据集翻译的给定提示被提供给两个模型,然后生成答案。然后,gpt-4-turbo总是被选为最佳答案,DPO将优化这一点。虽然这并不完全公平,但我没有预算让gpt-4对两个回复进行评分。此外,尽管是一个令人印象深刻的模型,但在我进行的测试中,GEITje chat似乎仍落后于gpt-4-turbo。
数据集总共包含56,137,090个token(提示+拒绝+选择的组合),测试集包含6,178,969个token(11.00%)。
训练过程
使用了优秀的对齐手册进行训练,并使用自定义的slurm脚本以适应我们的集群。它是完整训练的,没有使用LoRA或其他适配器。
模型在bfloat16和flash attention 2下训练,在两个节点上各使用四个A100 80GB GPU,耗时约11小时。感谢Flemish Super Computer提供的计算资源。
对于对话使用,模型依赖于Zephyr聊天模板,该系统消息兼容。*-sft的一小部分数据包含系统消息,因此假设模型至少能处理一些系统消息。
在早期迭代中,我发现使用对齐手册的默认值(beta=0.01)会导致不良结果(随机token的幻觉)。经过调查,似乎如此低的beta不适用于此数据集,因为它给了模型太多偏离其初始基础模型的空间。经过超参数搜索和手动分析结果指标,我选择了当前模型作为最佳模型,beta为0.1。
使用手册的配方:
model_name_or_path: BramVanroy/GEITje-7B-ultra-sft
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
dataset_mixer:
BramVanroy/ultra_feedback_dutch: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 8
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: BramVanroy/GEITje-ultra
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1536
num_train_epochs: 1
optim: adamw_torch
output_dir: data/GEITje-ultra
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
push_to_hub: true
save_strategy: "steps"
save_steps: 100
save_total_limit: 3
seed: 42
warmup_ratio: 0.1
训练超参数
训练期间使用的超参数:
- 学习率:5e-07
- 训练批次大小:4
- 评估批次大小:4
- 种子:42
- 分布式类型:多GPU
- 设备数量:8
- 梯度累积步数:4
- 总训练批次大小:128
- 总评估批次大小:32
- 优化器:Adam,betas=(0.9,0.999),epsilon=1e-08
- 学习率调度器类型:cosine
- 学习率调度器预热比例:0.1
- 训练轮数:1.0
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
奖励/选择 |
奖励/拒绝 |
奖励/准确率 |
奖励/边际 |
对数概率/拒绝 |
对数概率/选择 |
对数/拒绝 |
对数/选择 |
0.03 |
0.22 |
100 |
0.0260 |
-0.9740 |
-9.8635 |
0.9913 |
8.8895 |
-524.8940 |
-508.1891 |
-3.0753 |
-3.0315 |
0.0184 |
0.44 |
200 |
0.0164 |
-1.7162 |
-12.4772 |
0.9926 |
10.7610 |
-551.0317 |
-515.6115 |
-3.0349 |
-2.9873 |
0.0121 |
0.66 |
300 |
0.0142 |
-2.0575 |
-13.6818 |
0.9938 |
11.6244 |
-563.0778 |
-519.0242 |
-3.0325 |
-2.9835 |
0.0198 |
0.88 |
400 |
0.0139 |
-2.1431 |
-13.8857 |
0.9950 |
11.7426 |
-565.1163 |
-519.8801 |
-3.0293 |
-2.9801 |
框架版本
- Transformers 4.36.2
- Pytorch 2.1.2+cu121
- Datasets 2.14.6
- Tokenizers 0.15.0
英语开放LLM排行榜的结果。有关荷兰语的特定结果,请查看ScandEval。
详细结果可在此处找到点击此处
指标 |
值 |
平均 |
10.91 |
IFEval (0-Shot) |
37.23 |
BBH (3-Shot) |
12.88 |
MATH Lvl 5 (4-Shot) |
0.91 |
GPQA (0-shot) |
1.68 |
MuSR (0-shot) |
1.52 |
MMLU-PRO (5-shot) |
11.24 |