许可协议: cc-by-nc-4.0
任务标签: 特征提取
标签:
Salesforce/SFR-嵌入代码-400M_R
Salesforce研究的SFR-嵌入模型。
Salesforce/SFR-嵌入代码是一个通用的嵌入模型家族,适用于多语言和多任务的代码及文本检索。在多个代码检索任务中,它展现出优于各种开源代码嵌入模型的性能。
详情请参阅我们的论文!
伦理考量
此版本仅用于支持学术论文的研究目的。我们的模型、数据集和代码并非为所有下游用途专门设计或评估。我们强烈建议用户在部署此模型前评估并解决与准确性、安全性和公平性相关的潜在问题。我们鼓励用户考虑AI的普遍限制,遵守适用法律,并在选择用例时采用最佳实践,特别是对于错误或滥用可能严重影响人们生活、权利或安全的高风险场景。有关用例的进一步指导,请参考我们的使用政策和AI使用政策。
许可声明:
用户需自行评估与原始数据集和数据相关的任何许可或条款下的义务和责任。此版本仅用于支持学术论文的研究目的。
CoIR基准测试性能
模型 |
模型大小 |
CoIR平均(NDCG@10) |
SFR-嵌入代码 |
2B |
67.4 |
CodeSage-Large-v2 |
1.3B |
64.2 |
CodeSage-Large |
1.3B |
61.0 |
SFR-嵌入代码 |
400M |
61.9 |
CodeRankEmbed |
137M |
60.1 |
CodeSage-Base |
356M |
57.5 |
Voyage-Code-002 |
- |
56.3 |
CodeSage-Small |
130M |
54.4 |
SFR-嵌入团队 († 表示共同负责人)
- 刘烨
- 孟瑞
- Shafiq Rayhan Joty
- Silvio Savarese
- 熊才明 †
- 周英博 †
- Semih Yavuz †
如何运行
转换器
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
input_texts = [
"如何在Python中实现快速排序?",
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr",
]
model_path = 'Salesforce/SFR-Embedding-Code-400M_R'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')
outputs = model(**batch_dict)
embeddings = outputs.last_hidden_state[:, 0]
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:1] @ embeddings[1:].T) * 100
print("相似度分数:", scores.tolist())
句子转换器
需要 sentence_transformers>=2.7.0
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
sentences = [
"如何在Python中实现快速排序?",
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr",
]
model = SentenceTransformer('Salesforce/SFR-Embedding-Code-400M_R', trust_remote_code=True)
embeddings = model.encode(sentences)
similarities = cos_sim(embeddings[0], embeddings[1:])
print(similarities)
引用
@article{liu2024codexembed,
title={CodeXEmbed: 多语言多任务代码检索的通用嵌入模型家族},
author={刘烨 and 孟瑞 and Jot, Shafiq and Savarese, Silvio and 熊才明 and 周英博 and Yavuz, Semih},
journal={arXiv预印本 arXiv:2411.12644},
year={2024}
}