库名称: transformers
许可证: apache-2.0
标签:
这是Jamba模型的基础版本。我们现已发布更优的指令调优版本Jamba-1.5-Mini。如需更强性能,请查看升级版Jamba-1.5-Large。
Jamba模型卡
Jamba是最先进的混合SSM-Transformer大语言模型。它在传统Transformer模型基础上实现了吞吐量提升,同时在多数常见基准测试中达到或超越同规模领先模型水平。
作为首个生产级Mamba实现,Jamba为研究和应用开辟了新机遇。虽然初步实验已显示出积极效果,我们预期未来优化和探索将带来更大提升。
本模型卡针对Jamba基础版。这是一个预训练的混合专家(MoE)生成文本模型,激活参数120亿,总参数520亿(含所有专家)。支持256K上下文长度,可在单块80GB GPU上处理最多140K tokens。
完整细节请参阅白皮书和发布博客。
模型详情
- 开发方: AI21
- 模型类型: 联合注意力与Mamba架构(Jamba)
- 许可证: Apache 2.0
- 上下文长度: 256K
- 知识截止日期: 2024年3月5日
使用指南
前置要求
建议使用transformers
4.40.0或更高版本(最低要求4.39.0):
pip install transformers>=4.40.0
如需运行优化版Mamba实现,需先安装:
pip install mamba-ssm causal-conv1d>=1.2.0
模型需运行在CUDA设备上。
虽可不使用优化内核运行模型,但不推荐,这将导致显著延迟。如需关闭优化,加载模型时需指定use_mamba_kernels=False
。
运行模型
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("在最近的第58届超级碗中,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
注意:若使用transformers<4.40.0
,运行新架构需设置trust_remote_code=True
。
半精度加载
发布的检查点以BF16格式保存。如需以BF16/FP16加载至内存:
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16)
半精度下可启用FlashAttention2实现。由于模型过大无法单卡运行,需使用accelerate并行:
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8位量化加载
8位精度下,单块80GB GPU可支持140K序列长度。使用bitsandbytes量化时,建议排除Mamba模块以保持质量:
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
微调示例
Jamba作为基础模型可针对定制需求微调(包括对话/指令版本)。以下是使用PEFT库的微调示例(约需120GB GPU显存,示例为2xA100 80GB):
lora_config = LoraConfig(
r=8,
target_modules=[
"embed_tokens",
"x_proj", "in_proj", "out_proj",
"gate_proj", "up_proj", "down_proj",
"q_proj", "k_proj", "v_proj"
],
task_type="CAUSAL_LM",
bias="none"
)
基准测试结果
测试项 |
得分 |
HellaSwag |
87.1% |
Arc挑战赛 |
64.4% |
WinoGrande |
82.5% |
PIQA |
83.2% |
MMLU |
67.4% |
BBH |
45.4% |
TruthfulQA |
46.4% |
GSM8K (思维链) |
59.9% |
注意所有提示需添加'BOS'标记,部分评估框架可能未默认启用。
注意事项
Jamba是预训练基础模型,未针对指令/对话场景进行对齐。
作为基础模型,Jamba适用于作为微调、训练和开发定制解决方案的基础层。本模型未内置安全审核机制,使用时需自行添加保障措施。
关于AI21
AI21为企业构建可靠、实用、可扩展的AI解决方案。
Jamba是AI21新模型系列的首个产品,其指令版本即将登陆AI21平台。