语言:
LLM-jp-3 VILA 14B
本仓库提供由日本国立情报学研究所下属大语言模型研发中心开发的大型视觉语言模型(VLM)。
使用说明
Python版本要求: 3.10.12
-
克隆仓库并安装依赖库
git clone git@github.com:llm-jp/llm-jp-VILA.git
cd llm-jp-VILA
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install -e .
pip install -e ".[train]"
pip install git+https://github.com/huggingface/transformers@v4.36.2
cp -rv ./llava/train/transformers_replace/* ./venv/lib/python3.10/site-packages/transformers/
-
运行Python脚本。可自定义修改image_path
和query
参数
import argparse
from io import BytesIO
import requests
import torch
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (get_model_name_from_path,
process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
def load_image(image_file):
if image_file.startswith("http") or image_file.startswith("https"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def load_images(image_files):
out = []
for image_file in image_files:
image = load_image(image_file)
out.append(image)
return out
disable_torch_init()
model_checkpoint_path = "llm-jp/llm-jp-3-vila-14b"
model_name = get_model_name_from_path(model_checkpoint_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_checkpoint_path, model_name)
image_path = "图片路径"
image_files = [
image_path
]
images = load_images(image_files)
query = "<image>\n请描述这张图片。"
conv_mode = "llmjp_v3"
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
images_tensor = process_images(images, image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=[
images_tensor,
],
do_sample=False,
num_beams=1,
max_new_tokens=256,
use_cache=True,
)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
print(outputs)
模型架构
训练数据
模型训练分为三个阶段:
阶段0
使用以下数据集调整投影层参数:
阶段1
使用以下数据集调整投影层和LLM参数:
阶段2
使用以下数据集进行最终微调:
评估结果
我们在Heron Bench、JA-VLM-Bench-In-the-Wild和JA-VG-VQA500基准测试上进行了评估,使用gpt-4o-2024-05-13
作为评判标准。
Heron基准
JA-VLM野外基准
JA-VG视觉问答500
风险与限制
本仓库发布的模型仍处于研发早期阶段,其输出内容尚未完全符合社会规范、伦理标准及法律法规。
许可协议
模型权重基于Apache 2.0许可证发布。由于模型使用了OpenAI GPT-4生成的合成数据,使用者还需遵守OpenAI使用条款。
补充说明
关于synthdog-ja数据集的许可:虽然我们尝试联系《OCR-free Document Understanding Transformer》论文通讯作者确认,但未获回复。基于以下两点:
- 基于该数据集训练的donut-base模型采用MIT许可证
- 数据集使用的维基百科内容遵循CC-BY-SA协议
我们推定该数据集应遵循CC-BY-SA许可协议,并依此开展训练。