语言:
- 英文
标签:
- 特征提取
- 句子相似度
数据集:
- biu-nlp/abstract-sim
小部件:
- 句子相似度
- 特征提取
一个将抽象句子描述映射到符合描述的句子的模型。基于维基百科训练。使用load_finetuned_model
加载查询和句子编码器,使用encode_batch()
对句子进行编码。
注意:该方法采用双编码器架构。这是查询编码器;应与句子编码器一起使用。
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
from sklearn.metrics.pairwise import cosine_similarity
def load_finetuned_model():
sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")
return tokenizer, query_encoder, sentence_encoder
def encode_batch(model, tokenizer, sentences: List[str], device: str):
input_ids = tokenizer(sentences, padding=True, max_length=512, truncation=True, return_tensors="pt",
add_special_tokens=True).to(device)
features = model(**input_ids)[0]
features = torch.sum(features[:,1:,:] * input_ids["attention_mask"][:,1:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,1:], dim=1, keepdims=True), min=1e-9)
return features
使用示例:
tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
relevant_sentences = ["Fingersoft的母公司是Finger Group。",
"WHIRC——Wright-Hennepin的子公司",
"CK Life Sciences International (Holdings) Inc. (),或称CK Life Sciences,是CK Hutchison Holdings的子公司",
"EM Microelectronic-Marin(The Swatch Group的子公司)。",
"该公司目前是Jam Industries企业集团的一个部门。",
"Volt Technical Resources是Volt Workforce Solutions的业务部门,后者是Volt Information Sciences的子公司(目前在柜台交易代码为VISI.)。"
]
irrelevant_sentences = ["第二家公司被视为母公司的子公司。",
"该公司经历了多次转型。",
"该公司由其员工所有。",
"大型公司通过收购可能拥有特定市场领域的小公司来争夺市场份额。",
"母公司是指拥有另一家公司(或子公司)51%或以上有表决权股份的公司。",
"这是一家控股公司,通过其子公司在以下领域提供服务:石油和天然气、工业和基础设施、政府和电力。",
"RXVT Technologies不再是母公司的子公司。"
]
all_sentences = relevant_sentences + irrelevant_sentences
query = "<query>: 一家公司是更大公司的一部分。"
embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").detach().cpu().numpy()
query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").detach().cpu().numpy()
sims = cosine_similarity(query_embedding, embeddings)[0]
sentences_sims = list(zip(all_sentences, sims))
sentences_sims.sort(key=lambda x: x[1], reverse=True)
for s, sim in sentences_sims:
print(s, sim)
预期输出:
WHIRC——Wright-Hennepin的子公司 0.9396286
EM Microelectronic-Marin(The Swatch Group的子公司)。 0.93929046
Fingersoft的母公司是Finger Group。 0.936247
CK Life Sciences International (Holdings) Inc. (),或称CK Life Sciences,是CK Hutchison Holdings的子公司 0.9350312
该公司目前是Jam Industries企业集团的一个部门。 0.9273489
Volt Technical Resources是Volt Workforce Solutions的业务部门,后者是Volt Information Sciences的子公司(目前在柜台交易代码为VISI.)。 0.9005086
第二家公司被视为母公司的子公司。 0.6723645
这是一家控股公司,通过其子公司在以下领域提供服务:石油和天然气、工业和基础设施、政府和电力。 0.60081375
母公司是指拥有另一家公司(或子公司)51%或以上有表决权股份的公司。 0.59490484
该公司由其员工所有。 0.55286574
RXVT Technologies不再是母公司的子公司。 0.4321953
该公司经历了多次转型。 0.38889483
大型公司通过收购可能拥有特定市场领域的小公司来争夺市场份额。 0.25472647