许可证:MIT
微件示例:
- 文本:"Apoorv Umang Saxena|姓氏"
示例标题:"姓氏预测"
- 文本:"Apoorv Saxena|国家"
示例标题:"国家预测"
- 文本:"第二次世界大战|后续事件"
示例标题:"后续事件预测"
这是一个基于WikiKG90Mv2数据集从头训练的t5-small模型。方法详情请参阅:https://github.com/apoorvumang/kgt5/
该模型针对尾实体预测任务训练,即给定主语实体和关系,预测宾语实体。输入格式应为"<实体文本>|<关系文本>"。
我们使用原始文本标题和描述作为实体与关系的文本表示。原始文本来自ogb数据集(路径:dataset/wikikg90m-v2/mapping/entity.csv与relation.csv)。实体表示默认采用标题,若标题重复则通过描述消歧。若仍无法消歧,则使用维基数据ID(如Q123456)。
模型在4块1080Ti GPU上训练约1.5个epoch,单epoch耗时约5.5天。
评估时,我们对每个(s,r)输入对从解码器采样300次,剔除无效实体预测后按对数概率排序,并进行结果过滤。验证集MRR为0.22(完整排行榜见:https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2)
可通过以下代码调用预训练模型(实体ID映射等完整流程为简洁省略,需者可联系Apoorv:apoorvumang@gmail.com):
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
import torch
def getScores(ids, scores, pad_token_id):
"""从model.generate输出获取序列分数"""
scores = torch.stack(scores, dim=1)
log_probs = torch.log_softmax(scores, dim=2)
ids = ids[:,1:]
x = ids.unsqueeze(-1).expand(log_probs.shape)
needed_logits = torch.gather(log_probs, 2, x)
final_logits = needed_logits[:, :, 0]
final_logits[ids == pad_token_id] = 0
return final_logits.sum(dim=-1).cpu().detach().numpy()
def topkSample(input, model, tokenizer, num_samples=5, num_beams=1, max_output_length=30):
tokenized = tokenizer(input, return_tensors="pt")
out = model.generate(**tokenized,
do_sample=True,
num_return_sequences=num_samples,
output_scores=True,
return_dict_in_generate=True,
max_length=max_output_length)
out_str = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)
out_scores = getScores(out.sequences, out.scores, tokenizer.pad_token_id)
return sorted(zip(out_str, out_scores), key=lambda x:x[1], reverse=True)
def greedyPredict(input, model, tokenizer):
out_tokens = model.generate(tokenizer([input], return_tensors="pt").input_ids)
return tokenizer.batch_decode(out_tokens, skip_special_tokens=True)[0]
input = "Sophie Valdemarsdottir|贵族头衔"
topkSample(input, model, tokenizer, num_samples=5)
完整评估需加载实体别名列表(需约8GB内存,下载地址:
- 实体别名:https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle
- 关系别名:https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle
提交的验证/测试结果通过对每个输入采样300次后过滤得到。由于采样随机性,最终MRR可能存在微小波动(相比固定束搜索,大规模采样效果更优)。
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
k = 1
count_at_k = 0
for line in tqdm(open('valid.txt').readlines()[:1000]):
input, target = line.strip().split('\t')
if target in [pred[0] for pred in topkSample(input, model, tokenizer, num_samples=k)]:
count_at_k += 1
print(f'Hits@{k}未过滤: {count_at_k/1000}')