pipeline_tag: 文本生成
tags:
- PyTorch
- Transformers
- gpt2
license: 无许可证
language: 俄语
widget:
- text: "- 朱丽叶有7个甜甜圈,然后她吃了3个。她还剩多少个甜甜圈? -"
- text: "- 已经撸了4只兔狲,还剩6只需要撸。总共有多少只兔狲要撸? -"
- text: "- 先告诉我五乘九等于多少? -"
- text: "- 你咋这么嚣张呢? -"
- text: "- 嗨!您那儿一切都好吗? -"
俄语闲聊、演绎与常识推理模型
该模型是对话系统原型的核心组件,具有两大主要功能。
第一项功能是生成闲聊对话。输入上下文为历史对话记录(前1-10轮对话内容):
- 嗨,最近怎么样?
- 嗨,就那样吧。
- <<< 此处期待模型生成的回复 >>>
第二项功能是基于附加事实或"常识"进行问题解答。假设相关事实通过其他模型(如sbert_pq)从外部知识库检索获得。模型将利用给定事实和问题文本,构建符合语法且最简洁的人类化回答。相关事实应置于问题文本前,格式模拟对话者陈述:
- 今天是9月15日。现在是几月份?
- 九月
模型不要求所有添加上下文的事实都与问题严格相关。因此知识检索模型可以牺牲精确性换取覆盖率,允许包含冗余信息。当前版本的闲聊模型能自动筛选关键事实并忽略无关内容,最多支持前置5个事实。例如:
- 斯塔斯16岁。斯塔斯住在波多利斯克。斯塔斯没有私家车。斯塔斯住在哪里?
- 波多利斯克
某些情况下,模型可基于两个相互关联的前提进行三段论推理。推导得出的隐含结论不会显式呈现,而是用于间接生成答案:
- 如果亚里士多德是希腊哲学家,且所有哲学家都终有一死,那么亚里士多德会死吗?
- 会
从示例可见,模型输入的事实信息格式极其自然灵活。
除逻辑推理外,模型还能解决小学1-2年级水平的简单算术题(含两个数字参数):
- 2加8等于几?
- 10
模型版本与指标
当前发布的模型参数量为7.6亿(相当于sberbank-ai/rugpt3large_based_on_gpt2级别)。延迟测试集的算术题准确率如下:
基础模型 |
算术准确率 |
sberbank-ai/rugpt3large_based_on_gpt2 |
0.91 |
sberbank-ai/rugpt3medium_based_on_gpt2 |
0.70 |
sberbank-ai/rugpt3small_based_on_gpt2 |
0.58 |
tinkoff-ai/ruDialoGPT-small |
0.44 |
tinkoff-ai/ruDialoGPT-medium |
0.69 |
"算术准确率"列中的0.91表示91%的测试题目完全正确。任何与标准答案的偏差(如输出"120"而非"119")均计为错误。
使用示例
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "inkoziev/rugpt_chitchat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()
# 输入最后2-3轮对话,每轮以"-"开头单独成行
input_text = """<s>- 嗨!在干嘛呢?
- 嗨 :) 正打车呢
-"""
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
output_sequences = model.generate(input_ids=encoded_prompt, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
text = text[: text.find('</s>')]
print(text)
联系方式
如有使用疑问或改进建议,请联系mentalcomputing@gmail.com
引用格式:
@MISC{rugpt_chitchat,
author = {伊利亚·科济耶夫},
title = {具备常识推理的俄语闲聊模型},
url = {https://huggingface.co/inkoziev/rugpt_chitchat},
year = 2022
}