库名称: transformers
许可证: apache-2.0
标签:
Jamba-v0.1-9B
这是Jamba-v0.1的密集版本,提取了第一个专家的权重。
它不再使用混合专家(MoE)架构。详情请参考此脚本。
该版本可使用单张3090/4090显卡进行推理,使用方法与Jamba-v0.1完全相同。
Jamba原模型说明
Jamba是最先进的混合SSM-Transformer架构大语言模型。它在传统Transformer模型基础上实现了吞吐量提升,同时在同类尺寸模型中,大多数常见基准测试表现优于或持平最佳模型。
作为首个生产级Mamba实现,Jamba为研究和应用开辟了新机遇。虽然当前实验已显示出积极效果,我们预期未来优化和探索将进一步增强这些优势。
本模型卡针对Jamba基础版本。这是一个预训练的混合专家(MoE)生成文本模型,激活参数120亿,所有专家总参数520亿。支持256K上下文长度,可在单张80GB GPU上处理高达140K令牌。
完整详情请参阅发布博客。
模型详情
- 开发机构: AI21
- 模型类型: 联合注意力与Mamba架构(Jamba)
- 许可证: Apache 2.0
- 上下文长度: 256K
- 知识截止日期: 2024年3月5日
使用指南
环境准备
需使用transformers
4.39.0或更高版本:
pip install transformers>=4.39.0
为运行优化的Mamba实现,需先安装mamba-ssm
和causal-conv1d
:
pip install mamba-ssm causal-conv1d>=1.2.0
注意模型需部署在CUDA设备上。
虽然可不使用优化内核运行(通过加载模型时指定use_mamba_kernels=False
),但会导致显著延迟增加,故不推荐。
运行模型
当前运行新架构需设置trust_remote_code=True
:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
半精度加载
发布的检查点以BF16格式保存。加载至内存时需指定torch_dtype
:
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16)
半精度下可启用Attention块的FlashAttention2实现。由于模型尺寸超过单张80GB GPU容量,需使用accelerate进行并行化:
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8位量化加载
使用8位精度时,单张80GB GPU可支持最高140K序列长度。 可通过bitsandbytes进行量化。为避免质量下降,建议排除Mamba块的量化:
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
微调示例
Jamba作为基础模型可通过PEFT库进行微调:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
基准测试结果
测试项 |
得分 |
HellaSwag |
87.1% |
GSM8K (CoT) |
59.9% |
注意所有提示需添加'BOS'令牌,部分评估框架可能未默认启用。
注意事项
Jamba是未经指令/对话对齐的预训练基础模型。
作为基础模型,Jamba适用于作为微调和定制开发的底层架构。该模型未内置安全审核机制,实际使用时需自行添加保障措施。
关于AI21
AI21为企业构建可靠、实用且可扩展的AI解决方案。
Jamba是AI21新模型家族的首个成员,其指令调优版本可通过AI21平台获取测试版。