语言:
- 英文
标签:
- pytorch
- 因果语言模型
- pythia
许可证: apache-2.0
数据集:
- Dahoas/synthetic-instruct-gptj-pairwise
该模型基于EleutherAI/pythia-2.8b-deduped
在Dahoas/synthetic-instruct-gptj-pairwise
数据集上微调而成。
您可以通过Lambda Cloud体验该模型的在线演示。
模型详情
运行要求
运行该模型推理约需7GB显存。
快速开始
import torch
from transformers import AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model_name = "lambdalabs/pythia-2.8b-deduped-synthetic-instruct"
max_new_tokens = 2048
stop_token = "<|stop|>"
class 关键词停止标准(StoppingCriteria):
def __init__(self, 关键词id列表: list):
self.关键词 = 关键词id列表
def __call__(
self, 输入id: torch.LongTensor, 分数: torch.FloatTensor, **kwargs
) -> bool:
if 输入id[0][-1] in self.关键词:
return True
return False
分词器 = AutoTokenizer.from_pretrained(
model_name,
)
分词器.pad_token = 分词器.eos_token
分词器.add_tokens([stop_token])
停止id = [分词器.encode(w)[0] for w in [stop_token]]
停止标准 = 关键词停止标准(停止id)
生成器 = pipeline(
"文本生成",
model=model_name,
device=device,
max_new_tokens=max_new_tokens,
torch_dtype=torch.float16,
stopping_criteria=StoppingCriteriaList([停止标准]),
)
示例 = "如何制作煎蛋卷"
文本 = "问题:{}\n回答:".format(示例)
结果 = 生成器(
文本,
num_return_sequences=1,
)
输出 = 结果[0]["生成文本"]
print(输出)
输出示例:
问题:如何制作煎蛋卷
回答:制作煎蛋卷时,首先将两个鸡蛋打入碗中搅拌。加入少许牛奶和盐胡椒调味。用中高火加热不粘锅,放入一汤匙黄油。待黄油融化后倒入蛋液。当蛋液开始凝固时,用铲子掀起边缘让未凝固的蛋液流到底部。待蛋液完全凝固且无流动液体时,加入馅料后将蛋饼对折盛盘。<|stop|>
训练过程
模型基于Dahoas/synthetic-instruct-gptj-pairwise数据集训练。我们将原始数据集划分为训练集(前32000条样本)和验证集(剩余1144条样本)。
模型共训练4个周期,使用8块A100 80GB显卡耗时5小时完成。参数设置为:单卡批量大小2(全局批量大小16),初始学习率0.00001(采用线性衰减至零)。训练过程记录详见Weights and Biases。