库名称:transformers
许可证:其他
许可证名称:NVIDIA开放模型许可证
许可证链接:https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf
任务类型:文本生成
Hymba-1.5B-Base
üíæ GitHub   |    üìÑ 论文 |    üìú 博客  
模型概述
Hymba-1.5B-Base 是一款基础文本生成模型,适用于多种自然语言生成任务。
该模型采用混合架构,结合了并行运行的Mamba和注意力头。元标记(一组可学习的标记,预置于每个提示前)有助于提升模型效能。模型在两层之间以及单层内的头之间共享KV缓存。90%的注意力层采用滑动窗口注意力机制。
此模型已开放商用。
模型开发者: NVIDIA
训练时间: Hymba-1.5B-Base 的训练时间为2024年9月1日至2024年11月10日。
许可证:
本模型依据 NVIDIA开放模型许可证协议 发布。
模型架构
‚ö°Ô∏è 我们已在GitHub上发布了Hymba的极简实现,帮助开发者理解其设计原理并应用于自己的模型中。查看详情:barebones-hymba。
Hymba-1.5B-Base 的嵌入维度为1600,包含25个注意力头,MLP中间维度为5504,总层数为32层,其中16个SSM状态层、3个全注意力层,其余为滑动窗口注意力层。与标准Transformer不同,Hymba的每个注意力层均采用标准注意力头与Mamba头并行的混合设计。此外,模型还使用了分组查询注意力(GQA)和旋转位置嵌入(RoPE)。
架构特点:
- 在同一层内融合注意力头和SSM头,实现对相同输入的并行互补处理。
性能亮点
- Hymba-1.5B-Base 在2B参数以下的公开模型中表现最优。
模型使用
步骤1:环境配置
由于Hymba-1.5B-Base采用了依赖PyTorch 2.5的FlexAttention,我们提供两种环境配置方式:
- 本地安装:使用提供的
setup.sh
安装相关依赖(支持CUDA 12.1/12.4):
wget --header="Authorization: Bearer YOUR_HF_TOKEN" https://huggingface.co/nvidia/Hymba-1.5B-Base/resolve/main/setup.sh
bash setup.sh
- Docker:提供已安装所有依赖的Docker镜像,运行以下命令启动容器:
docker pull ghcr.io/tilmto/hymba:v1
docker run --gpus all -v /home/$USER:/home/$USER -it ghcr.io/tilmto/hymba:v1 bash
步骤2:与Hymba-1.5B-Base对话
环境配置完成后,可通过以下脚本与模型交互:
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModel
import torch
repo_name = "nvidia/Hymba-1.5B-Base"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
model = model.cuda().to(torch.bfloat16)
prompt = input()
inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
outputs = model.generate(**inputs, max_length=64, do_sample=False, temperature=0.7, use_cache=True)
response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(f"模型回复: {response}")
微调Hymba
LMFlow 是完整的语言模型微调工具链。以下示例展示如何使用LMFlow微调Hymba-1.5B-Base
模型:
-
使用Docker
docker pull ghcr.io/tilmto/hymba:v1
docker run --gpus all -v /home/$USER:/home/$USER -it ghcr.io/tilmto/hymba:v1 bash
-
安装LMFlow
git clone https://github.com/OptimalScale/LMFlow.git
cd LMFlow
conda create -n lmflow python=3.9 -y
conda activate lmflow
conda install mpi4py
pip install -e .
-
运行微调命令
cd LMFlow
bash ./scripts/run_finetune_hymba.sh
使用LMFlow时,您还可以在自定义数据集上微调模型,只需将数据转换为LMFlow数据格式。除全参数微调外,还可通过DoRA、LoRA、LISA、Flash Attention等技术高效微调。详情请参阅LMFlow for Hymba文档。
评估
我们使用LM Evaluation Harness
进行评估,命令如下:
git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness
git fetch --all --tags
git checkout tags/v0.4.4
cd lm-evaluation-harness
pip install -e .
lm_eval --model hf --model_args pretrained=nvidia/Hymba-1.5B-Base,dtype=bfloat16,trust_remote_code=True \
--tasks mmlu \
--num_fewshot 5 \
--batch_size 1 \
--output_path ./hymba_HF_base_lm-results \
--log_samples
lm_eval --model hf --model_args pretrained=nvidia/Hymba-1.5B-Base,dtype=bfloat16,trust_remote_code=True \
--tasks arc_easy,arc_challenge,piqa,winogrande,hellaswag \
--num_fewshot 0 \
--batch_size 1 \
--output_path ./hymba_HF_base_lm-results \
--log_samples
lm_eval --model hf --model_args pretrained=nvidia/Hymba-1.5B-Base,dtype=bfloat16,trust_remote_code=True \
--tasks squad_completion \
--num_fewshot 1 \
--batch_size 1 \
--output_path ./hymba_HF_base_lm-results \
--log_samples
局限性
模型训练数据包含从互联网爬取的有毒语言、不安全内容和社会偏见,因此可能放大这些偏见并在有毒提示下生成有害回复。即使提示本身不含攻击性内容,模型仍可能生成不准确、遗漏关键信息或包含无关/冗余文本的答案。
测试表明该模型易受越狱攻击影响。若在RAG或代理场景中使用,建议实施严格的输出验证控制,确保用户可控输出符合预期用途的安全要求。
伦理考量
NVIDIA认为可信AI是共同责任,我们已制定政策与实践以支持广泛AI应用开发。开发者下载或使用本模型时,应与其内部模型团队协作,确保模型符合相关行业要求并防范意外滥用。
请通过此链接报告安全问题或AI伦理问题。
引用
@misc{dong2024hymbahybridheadarchitecturesmall,
title={Hymba: A Hybrid-head Architecture for Small Language Models},
author={Xin Dong and Yonggan Fu and Shizhe Diao and Wonmin Byeon and Zijia Chen and Ameya Sunil Mahabaleshwarkar and Shih-Yang Liu and Matthijs Van Keirsbilck and Min-Hung Chen and Yoshi Suhara and Yingyan Lin and Jan Kautz and Pavlo Molchanov},
year={2024},
eprint={2411.13676},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2411.13676},
}