许可协议: cc-by-nc-4.0
任务标签: 特征提取
标签:
Salesforce/SFR-嵌入代码-2B_R
Salesforce研究院开发的SFR-嵌入模型。
Salesforce/SFR-嵌入代码是一个通用型嵌入模型系列,适用于多语言和多任务的代码与文本检索。在多项代码检索任务中,该模型展现出优于各类开源代码嵌入模型的卓越性能。
详情请参阅我们的论文!
我们还提供4亿参数规模的模型Salesforce/SFR-嵌入代码-400_R
伦理考量
本次发布仅用于支持学术论文的研究目的。我们的模型、数据集和代码并非为所有下游用途专门设计或评估。我们强烈建议用户在部署前评估并解决与准确性、安全性和公平性相关的潜在问题。我们鼓励用户考虑人工智能的普遍局限性,遵守适用法律,并在选择应用场景时采用最佳实践,特别是对于错误或滥用可能严重影响人们生活、权利或安全的高风险场景。更多使用指南请参考我们的可接受使用政策和人工智能可接受使用政策。
许可声明:
用户需自行评估与原始数据集和数据相关的许可或条款下的任何义务或责任。本次发布仅用于支持学术论文的研究目的。
本发布模型是基于Gemma微调的版本,Gemma的使用受ai.google.dev/gemma/terms的Gemma使用条款约束。此外,本模型使用限制遵循Gemma禁止使用政策(ai.google.dev/gemma/prohibited_use_policy),该政策通过引用并入本协议。
CoIR基准测试性能
模型 |
模型规模 |
CoIR平均值(NDCG@10) |
SFR-嵌入代码 |
20亿参数 |
67.4 |
CodeSage-Large-v2 |
13亿参数 |
64.2 |
CodeSage-Large |
13亿参数 |
61.0 |
SFR-嵌入代码 |
4亿参数 |
61.9 |
CodeRankEmbed |
1.37亿参数 |
60.1 |
CodeSage-Base |
3.56亿参数 |
57.5 |
Voyage-Code-002 |
- |
56.3 |
CodeSage-Small |
1.3亿参数 |
54.4 |
SFR-嵌入团队(†表示共同负责人)
- 刘烨
- 孟瑞
- 沙菲克·雷汉·乔蒂
- 西尔维奥·萨瓦雷斯
- 熊才明 †
- 周英博 †
- 塞米赫·亚武兹 †
运行方法
转换器
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
query_instruction_example = "给定代码或文本,检索相关内容"
queries = [
"如何用Python实现快速排序?"
]
passages = [
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr"
]
model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Code-2B_R', trust_remote_code=True)
max_length = 32768
query_embeddings = model.encode_queries(queries, instruction=query_instruction_example, max_length=max_length)
passage_embeddings = model.encode_corpus(passages, max_length=max_length)
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)
scores = (query_embeddings @ passage_embeddings.T) * 100
print(scores.tolist())
句子转换器
from sentence_transformers import SentenceTransformer
query_instruction_example = "指令:给定代码或文本,检索相关内容\n查询: "
queries = ["如何用Python实现快速排序?"]
passages = [
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr"
]
model = SentenceTransformer('Salesforce/SFR-Embedding-Code-2B_R', trust_remote_code=True)
query_embeddings = model.encode(queries, prompt=query_instruction_example)
passage_embeddings = model.encode(passages)
similarities = model.similarity(query_embeddings, passage_embeddings)
print(similarities)
引用
@article{liu2024codexembed,
title={CodeXEmbed:面向多语言多任务代码检索的通用嵌入模型家族},
author={刘烨 and 孟瑞 and 乔蒂, 沙菲克 and 萨瓦雷斯, 西尔维奥 and 熊才明 and 周英博 and 亚武兹, 塞米赫},
journal={arXiv预印本 arXiv:2411.12644},
year={2024}
}