许可证:apache-2.0
数据集:
- cerebras/SlimPajama-627B
语言:
- en
TinyLlama-1.1B-v1.1
我们采用了与Llama 2完全相同的架构和分词器,这意味着TinyLlama可以即插即用地应用于许多基于Llama的开源项目中。此外,TinyLlama仅包含11亿参数,这种紧凑性使其能够满足多种对计算和内存占用有限制的应用需求。
概述
在本项目中,我们不仅训练了一个单一的TinyLlama模型,还首先在1.5万亿token的语料库上训练TinyLlama以获得基础语言能力。随后,我们通过三种不同的数据采样方式持续预训练,将这个模型转化为三个不同的模型。关于这一过程的视觉呈现,请参考下图。

预训练
由于这些问题(bug1、bug2),我们尝试重新训练TinyLlama以提供更好的模型。我们使用2T token训练模型,并将预训练分为三个阶段:1)基础预训练,2)特定领域的持续预训练,3)冷却阶段。
基础预训练
在这一初始阶段,我们仅使用SlimPajama训练模型以发展其常识推理能力。在基础预训练期间,模型训练了1.5T token。由于我们使用的集群每个节点配备4块A100-40G显卡,且仅在节点内分片模型权重,因此本次训练只能将批次大小设置为约180万。
特定领域的持续预训练
在这一阶段,我们引入了三种不同的语料库:SlimPajama(与第一阶段相同)、数学与代码(Starcoder和Proof Pile)以及中文(Skypile)。这种方法使我们能够开发出具备专项能力的三种变体模型。
在此阶段的前约60亿token中,我们线性增加了领域特定语料库(不包括SlimPajama,因为其比例与第一阶段保持一致)的采样比例。这种逐步增加采样的策略旨在调整预训练数据的分布,确保训练过程更加稳定。在采样比例调整阶段之后,我们继续以稳定的采样策略预训练模型,直至达到约1.85T token。
冷却阶段
在预训练结束时实施冷却阶段已成为实现更好模型收敛的关键技术。然而,由于我们一开始就采用了余弦学习率策略,因此很难像MiniCPM或DeepSeek那样调整冷却阶段的学习率。因此,我们尝试通过调整批次大小来实现冷却。具体来说,在冷却阶段,我们将批次大小从180万增加到720万,同时保持原始的余弦学习率计划。
TinyLlama模型家族
经过广泛而详细的预训练过程,我们现在发布了三个专用版本的模型:
- TinyLlama_v1.1:标准版本,用于通用目的。
- TinyLlama_v1.1_Math&Code:具备更强的数学和代码能力。
- TinyLlama_v1.1_Chinese:对中文有较好的理解能力。
数据
以下是我们在每个阶段的数据分布:
TinyLlama_v1.1
语料库 |
基础预训练 |
特定领域持续预训练 |
冷却阶段 |
Slimpajama |
100.0 |
100.0 |
100.0 |
TinyLlama_v1.1_math_code
语料库 |
基础预训练 |
特定领域持续预训练 |
冷却阶段 |
Slimpajama |
100.0 |
75.0 |
75.0 |
Starcoder |
- |
15.0 |
15.0 |
Proof_pile |
- |
10.0 |
10.0 |
TinyLlama_v1.1_chinese
语料库 |
基础预训练 |
特定领域持续预训练 |
冷却阶段 |
Slimpajama |
100.0 |
50.0 |
50.0 |
Skypile |
- |
50.0 |
50.0 |
使用方法
需要安装transformers>=4.31。更多信息请查看TinyLlama GitHub页面。
from transformers import AutoTokenizer
import transformers
import torch
model = "TinyLlama/TinyLlama_v1.1"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
torch_dtype=torch.float16,
device_map="auto",
)
sequences = pipeline(
'TinyLlama项目旨在用3万亿token预训练一个11亿参数的Llama模型。通过适当的优化,我们可以在90天内使用16块A100-40G显卡完成训练🚀🚀。训练已于2023-09-01开始。',
do_sample=True,
top_k=10,
num_return_sequences=1,
repetition_penalty=1.5,
eos_token_id=tokenizer.eos_token_id,
max_length=500,
)
for seq in sequences:
print(f"结果: {seq['generated_text']}")
评估
模型 |
预训练Token |
HellaSwag |
Obqa |
WinoGrande |
ARC_c |
ARC_e |
BoolQ |
Piqa |
平均 |
Pythia-1.0B |
300B |
47.16 |
31.40 |
53.43 |
27.05 |
48.99 |
60.83 |
69.21 |
48.30 |
TinyLlama-1.1B-intermediate-step-1431k-3T |
3T |
59.20 |
36.00 |
59.12 |
30.12 |
55.25 |
57.83 |
73.29 |
52.99 |
TinyLlama-1.1B-v1.1 |
2T |
61.47 |
36.80 |
59.43 |
32.68 |
55.47 |
55.99 |
73.56 |
53.63 |
TinyLlama-1.1B-v1_math_code |
2T |
60.80 |
36.40 |
60.22 |
33.87 |
55.20 |
57.09 |
72.69 |
53.75 |
TinyLlama-1.1B-v1.1_chinese |
2T |
58.23 |
35.20 |
59.27 |
31.40 |
55.35 |
61.41 |
73.01 |
53.41 |