license: apache-2.0
language:
- en
- zh
base_model:
- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
- BlinkDL/rwkv-7-world
pipeline_tag: text-generation
library_name: transformers
ARWKV🪿
论文链接👁️ | Github✅
ARWKV-R1-1B5 (预览版 0.1)
预览版采用RWKV-7时间混合与Transformer MLP架构
📌 概述
一切尽在RWKV
这是基于RNN的70亿参数模型的早期预览版,通过从DeepSeek-R1-Distill-Qwen-1.5B进行三阶段知识蒸馏训练而成(仅应用第二阶段,未进行SFT或DPO),上下文长度为2k。作为基础版本,它展示了:
- ✅ RWKV-7的高效循环机制
- ✅ 无自注意力,完全O(n)复杂度
- ✅ 恒定显存占用
- ✅ 单GPU可训练
路线图说明:我们将很快开源不同增强版本,包括:
- 🚀 16k+上下文支持
- 🧮 数学专项优化
- 📚 强化学习增强的推理模型
使用方法
pip3 install --upgrade rwkv-fla transformers
训练前设置:export WKV_MODE=chunk
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"RWKV-Red-Team/ARWKV-R1-1B5",
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
"RWKV-Red-Team/ARWKV-R1-1B5"
)
system_prompt = "你是一个世界级的 trivia AI —— 请提供准确、简洁的回答。"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text = text + "<think>"
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=8192, do_sample=True,tokenizer=tokenizer,stop_strings=[""])
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
print("流式输出:")
for new_text in streamer:
print(new_text, end="", flush=True)
thread.join()
输出示例:
流式输出:
🔑 核心特性
组件 |
规格 |
备注 |
架构 |
RWKV-7时间混合 + SwiGLU |
混合设计 |
上下文窗口 |
2048训练CTX |
预览版限制 |
训练token数 |
4000万 |
蒸馏为主 |
精度 |
推荐FP16推理(需16G显存) |
比BF16快15%↑ |
🏗️ 架构亮点
核心修改流程
Transformer解码层:
- 多头潜在注意力(MLA)
+ RWKV-7时间混合 (公式3)
- RoPE位置编码
+ 状态循环机制
= 混合层输出
应用案例