语言:
- 英语
标签:
- 特征提取
- 句子相似度
数据集:
- biu-nlp/abstract-sim
组件:
- 句子相似度
- 特征提取
一个将抽象句子描述映射到符合描述的句子的模型。基于维基百科训练。使用load_finetuned_model
加载查询和句子编码器,使用encode_batch()
通过模型对句子进行编码。
注意:该方法采用双编码器架构。这是句子编码器;需与查询编码器配合使用。
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
from sklearn.metrics.pairwise import cosine_similarity
def load_finetuned_model():
sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")
return tokenizer, query_encoder, sentence_encoder
def encode_batch(model, tokenizer, sentences: List[str], device: str):
input_ids = tokenizer(sentences, padding=True, max_length=512, truncation=True, return_tensors="pt",
add_special_tokens=True).to(device)
features = model(**input_ids)[0]
features = torch.sum(features[:,1:,:] * input_ids["attention_mask"][:,1:].unsqueeze(-1), dim=1) / torch.clamp(torch.sum(input_ids["attention_mask"][:,1:], dim=1, keepdims=True), min=1e-9)
return features
使用示例:
tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
relevant_sentences = [
"Fingersoft的母公司是Finger Group。",
"WHIRC——Wright-Hennepin的子公司",
"CK Life Sciences International (Holdings) Inc. (),或称CK Life Sciences,是CK Hutchison Holdings的子公司",
"EM Microelectronic-Marin(斯沃琪集团的子公司)。",
"该公司目前是Jam Industries企业集团的一个部门。",
"Volt Technical Resources是Volt Workforce Solutions的业务部门,后者是Volt Information Sciences的子公司(目前在场外交易市场以VISI.交易)。"
]
irrelevant_sentences = [
"第二家公司被视为母公司的子公司。",
"该公司经历了多次转型。",
"该公司由其员工所有。",
"大公司通过收购可能拥有特定市场领域的小公司来争夺市场份额。",
"母公司是指拥有另一家公司(或子公司)51%或以上有表决权股份的公司。",
"这是一家控股公司,通过其子公司在以下领域提供服务:石油和天然气、工业和基础设施、政府和电力。",
"RXVT Technologies不再是母公司的子公司。"
]
all_sentences = relevant_sentences + irrelevant_sentences
query = "<查询>:一家公司是更大公司的一部分。"
embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").detach().cpu().numpy()
query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").detach().cpu().numpy()
sims = cosine_similarity(query_embedding, embeddings)[0]
sentences_sims = list(zip(all_sentences, sims))
sentences_sims.sort(key=lambda x: x[1], reverse=True)
for s, sim in sentences_sims:
print(s, sim)
预期输出:
WHIRC——Wright-Hennepin的子公司 0.9396286
EM Microelectronic-Marin(斯沃琪集团的子公司)。 0.93929046
Fingersoft的母公司是Finger Group。 0.936247
CK Life Sciences International (Holdings) Inc. (),或称CK Life Sciences,是CK Hutchison Holdings的子公司 0.9350312
该公司目前是Jam Industries企业集团的一个部门。 0.9273489
Volt Technical Resources是Volt Workforce Solutions的业务部门,后者是Volt Information Sciences的子公司(目前在场外交易市场以VISI.交易)。 0.9005086
第二家公司被视为母公司的子公司。 0.6723645
这是一家控股公司,通过其子公司在以下领域提供服务:石油和天然气、工业和基础设施、政府和电力。 0.60081375
母公司是指拥有另一家公司(或子公司)51%或以上有表决权股份的公司。 0.59490484
该公司由其员工所有。 0.55286574
RXVT Technologies不再是母公司的子公司。 0.4321953
该公司经历了多次转型。 0.38889483
大公司通过收购可能拥有特定市场领域的小公司来争夺市场份额。 0.25472647