语言:
- 英文
任务标签: 文本分类
标签:
- 预训练模型
许可证: Apache-2.0
库名称: sentence-transformers
基础模型:
- Qwen/Qwen2.5-7B
Qwen2.5-7B-embed-base
模型详情
Qwen2.5是一个包含不同规模解码器语言模型的系列。针对每个规模,我们发布了基础语言模型和对齐的聊天模型。该模型基于Transformer架构,采用SwiGLU激活函数、注意力QKV偏置、分组查询注意力等机制。此外,我们还改进了分词器,使其能自适应多种自然语言和代码。
系统要求
Qwen2.5的代码已集成至最新版Hugging Face transformers库,建议安装transformers>=4.37.0
版本,否则可能遇到以下错误:
KeyError: 'Qwen2.5'
使用说明
本模型已移除'lm_head'层,适用于生成嵌入向量。由于需要进一步微调(如intfloat/e5-mistral-7b-instruct所示),其默认表现可能不尽理想。
推理示例
from sentence_transformers import SentenceTransformer
import torch
model = SentenceTransformer("ssmits/Qwen2.5-7B-embed-base")
sentences = [
"今天天气真好。",
"外面阳光明媚!",
"他开车去了体育场。",
]
embeddings = model.encode(sentences)
print(embeddings.shape)
embeddings_tensor = torch.tensor(embeddings)
similarities = torch.nn.functional.cosine_similarity(embeddings_tensor.unsqueeze(0), embeddings_tensor.unsqueeze(1), dim=2)
print(similarities)
注意:测试中显存占用超过24GB(RTX 4090),建议使用A100或A6000进行推理。
原生Transformers推理
若不使用sentence-transformers库,可按以下方式操作:首先通过transformer模型处理输入,然后对上下文词嵌入执行适当的池化操作。
from transformers import AutoTokenizer, AutoModel
import torch
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
sentences = ['这是示例句子', '每个句子都将被转换']
tokenizer = AutoTokenizer.from_pretrained('ssmits/Qwen2.5-7B-embed-base')
model = AutoModel.from_pretrained('ssmits/Qwen2.5-7B-embed-base')
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**encoded_input)
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("句子嵌入向量:")
print(sentence_embeddings)
多GPU启用方法
from transformers import AutoModel
from torch.nn import DataParallel
model = AutoModel.from_pretrained("ssmits/Qwen2.5-7B-embed-base")
for module_key, module in model._modules.items():
model._modules[module_key] = DataParallel(module)