UniXcoder基础版模型卡
模型详情
模型描述
UniXcoder是一个统一的多模态预训练模型,利用多模态数据(即代码注释和抽象语法树)预训练代码表示。
- 开发团队: 微软团队
- 共享方 [可选]: Hugging Face
- 模型类型: 特征工程
- 支持语言 (NLP): 英文
- 许可证: Apache-2.0
- 相关模型:
- 更多信息参考:
使用说明
1. 依赖项
- pip安装torch
- pip安装transformers
2. 快速入门
我们实现了一个使用UniXcoder的类,您可以参考以下代码构建UniXcoder。
通过以下命令下载类文件:
wget https://raw.githubusercontent.com/microsoft/CodeBERT/master/UniXcoder/unixcoder.py
import torch
from unixcoder import UniXcoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniXcoder("microsoft/unixcoder-base")
model.to(device)
接下来我们将展示不同模式下的零样本任务示例,包括:
- 代码搜索(仅编码器)
- 代码补全(仅解码器)
- 函数名预测(编码器-解码器)
- API推荐(编码器-解码器)
- 代码摘要(编码器-解码器)
3. 仅编码器模式
以代码搜索为例。
1) 代码与自然语言嵌入
获取代码片段嵌入的示例:
func = "def f(a,b): if a>b: return a else return b"
tokens_ids = model.tokenize([func],max_length=512,mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
tokens_embeddings,max_func_embedding = model(source_ids)
func = "def f(a,b): if a<b: return a else return b"
tokens_ids = model.tokenize([func],max_length=512,mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
tokens_embeddings,min_func_embedding = model(source_ids)
nl = "return maximum value"
tokens_ids = model.tokenize([nl],max_length=512,mode="<encoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
tokens_embeddings,nl_embedding = model(source_ids)
print(max_func_embedding.shape)
print(max_func_embedding)
2) 代码与自然语言相似度
计算自然语言与两个函数的余弦相似度。尽管两个函数仅有一个运算符差异(<
和>
),UniXcoder仍能区分它们。
norm_max_func_embedding = torch.nn.functional.normalize(max_func_embedding, p=2, dim=1)
norm_min_func_embedding = torch.nn.functional.normalize(min_func_embedding, p=2, dim=1)
norm_nl_embedding = torch.nn.functional.normalize(nl_embedding, p=2, dim=1)
max_func_nl_similarity = torch.einsum("ac,bc->ab",norm_max_func_embedding,norm_nl_embedding)
min_func_nl_similarity = torch.einsum("ac,bc->ab",norm_min_func_embedding,norm_nl_embedding)
print(max_func_nl_similarity)
print(min_func_nl_similarity)
4. 仅解码器模式
以代码补全为例。
context = """
def f(data,file_path):
# write json data into file_path in python language
"""
tokens_ids = model.tokenize([context],max_length=512,mode="<decoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)
prediction_ids = model.generate(source_ids, decoder_only=True, beam_size=3, max_length=128)
predictions = model.decode(prediction_ids)
print(context+predictions[0][0])
5. 编码器-解码器模式
提供三个示例:函数名预测、API推荐和代码摘要。
1) 函数名预测
context = """
def <mask0>(data,file_path):
data = json.dumps(data)
with open(file_path, 'w') as f:
f.write(data)
"""
tokens_ids = model.tokenize([context],max_length=512,mode="<encoder-decoder>")
source_ids = torch.tensor(tokens_ids).to(device)
prediction_ids = model.generate(source_ids, decoder_only=False, beam_size=3, max_length=128)
predictions = model.decode(prediction_ids)
print([x.replace("<mask0>","").strip() for x in predictions[0]])
2) API推荐
context = """
def write_json(data,file_path):
data = <mask0>(data)
with open(file_path, 'w') as f:
f.write(data)
"""
tokens_ids = model.tokenize([context],max_length=512,mode="<encoder-decoder>")
source_ids = torch.tensor(tokens_ids).to(device)
prediction_ids = model.generate(source_ids, decoder_only=False, beam_size=3, max_length=128)
predictions = model.decode(prediction_ids)
print([x.replace("<mask0>","").strip() for x in predictions[0]])
3) 代码摘要
context = """
# <mask0>
def write_json(data,file_path):
data = json.dumps(data)
with open(file_path, 'w') as f:
f.write(data)
"""
tokens_ids = model.tokenize([context],max_length=512,mode="<encoder-decoder>")
source_ids = torch.tensor(tokens_ids).to(device)
prediction_ids = model.generate(source_ids, decoder_only=False, beam_size=3, max_length=128)
predictions = model.decode(prediction_ids)
print([x.replace("<mask0>","").strip() for x in predictions[0]])
参考文献
如果您使用此代码或UniXcoder,请考虑引用我们。
@article{guo2022unixcoder,
title={UniXcoder: Unified Cross-Modal Pre-training for Code Representation},
author={Guo, Daya and Lu, Shuai and Duan, Nan and Wang, Yanlin and Zhou, Ming and Yin, Jian},
journal={arXiv preprint arXiv:2203.03850},
year={2022}
}