许可协议:Apache-2.0
数据集:
- ZTE-AIM/Curr-ReFT-data
基础模型:
- Qwen/Qwen2.5-VL-3B-Instruct
- Qwen/Qwen2.5-VL-7B-Instruct
任务标签:图文生成
Curr-ReFT数据集
[📂 GitHub]
[🤗 HF数据集]
Curr-ReFT模型
[🤗 3B版模型]
[🤗 7B版模型]
模型概述
本模型是基于Qwen2.5-VL通过创新性Curr-ReFT方法微调的多模态大语言模型。训练过程分为两个阶段:先通过课程强化学习逐步提升任务复杂度,再基于拒绝样本进行自我优化以保持基础能力。该模型显著提升了视觉语言理解与推理能力,特别适用于视觉推理、精细图像理解和多模态问题求解等复杂任务。凭借强大的多模态推理能力,Curr-ReFT成为能应对跨领域挑战的智能助手,具有更高的准确性和情境感知力。
训练配置
- 框架:采用开源R1-V库,以Qwen2.5-VL-Instruct为基础模型,提供3B/7B两种规格
梯度反转优化配置如下:
最大像素值 401408
单设备训练批次大小:1
梯度累积步数:1
学习率:1.0e-5
训练周期数:1.0
学习率调度器:余弦衰减
启用BF16:是
注意力机制:flash_attn fa2
使用方式
可通过Hugging Face的transformers
库加载模型:
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import torch
from qwen_vl_utils import process_vision_info
MODEL_ID = "Curr-ReFT-3B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to("cuda").eval()
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "<图片路径>"},
{"type": "text", "text": "提示:请回答问题并在最后给出最终答案。问题:最后一朵雏菊上应该写哪个数字?"},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=4096)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
研发机构
模型联系人
- huilin_deng@mail.ustc.edu.cn
- zoudinghust@gmail.com
- 214711069@csu.edu.cn