---
library_name: transformers
tags: []
---
# Mamba

This repository contains the `mamba-2.8b` model compatible with `transformers`. The checkpoint files are kept untouched, while the full `config.json` and the tokenizer have been pushed to this repo.
## Usage

Until `transformers=4.39.0` is officially released, you need to install `transformers` from the `main` branch:
```bash
pip install git+https://github.com/huggingface/transformers@main
```
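As a quick sanity check (a minimal sketch, not part of the original instructions), you can verify that the development install exposes the Mamba classes:

```python
# Sketch: this import only succeeds once a transformers version with Mamba support is installed.
from transformers import MambaConfig, MambaForCausalLM

print("Mamba support available")
```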
It is also recommended to install `causal_conv_1d` and `mamba-ssm` for the optimized kernels:
```bash
pip install causal-conv1d>=1.2.0
pip install mamba-ssm
```
If either of these is not installed, the model falls back to the basic "eager" implementation; once they are installed, the optimized `cuda` kernels are used.
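To check which path will be taken, a minimal sketch is to test whether the optimized packages can be imported (the module names `mamba_ssm` and `causal_conv1d` correspond to the packages installed by the commands above):

```python
# Minimal sketch: check whether the optimized kernels are importable.
try:
    import mamba_ssm        # selective-scan CUDA kernels
    import causal_conv1d    # causal conv1d CUDA kernels
    print("Optimized CUDA kernels will be used")
except ImportError:
    print("Falling back to the 'eager' implementation")
```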
## Generation

You can use the classic `generate` API:
```python
>>> from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> input_ids = tokenizer("How are you doing lately?", return_tensors="pt")["input_ids"]

>>> out = model.generate(input_ids, max_new_tokens=10)
>>> print(tokenizer.batch_decode(out))
["How are you doing lately?\n\nNice to see you."]
```
## PEFT fine-tuning example

When fine-tuning with the `peft` library, it is recommended to keep the model in float32 precision:
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```
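After training, a common follow-up (shown here only as a sketch; the output path is hypothetical) is to save the LoRA adapter and reload it on top of the base model for inference:

```python
# Sketch: save the trained LoRA adapter (hypothetical path) and reload it for inference.
trainer.save_model("./mamba-lora-adapter")

from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
peft_model = PeftModel.from_pretrained(base_model, "./mamba-lora-adapter")
```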