license: apache-2.0
虚拟编译器:汇编代码搜索的终极解决方案
简介
本仓库包含ACL 2024论文《虚拟编译器:汇编代码搜索的终极解决方案》中的模型及对应评估数据集。
虚拟编译器是指能够将任意编程语言编译为底层汇编代码的大型语言模型(LLM)。该虚拟编译器模型基于340亿参数的CodeLlama构建,托管于elsagranger/VirtualCompiler。
我们通过force-exec.py脚本强制执行生成的虚拟汇编代码,评估其与真实汇编代码的相似度,相关评估数据集存放于virtual_assembly_and_ground_truth目录。
通过汇编代码搜索这一下游任务评估虚拟编译器的有效性,评估数据集托管于elsagranger/AssemblyCodeSearchEval。
使用指南
使用FastChat和vllm worker部署模型。请在独立终端(如tmux
)中运行以下命令:
LOGDIR="" python3 -m fastchat.serve.openai_api_server \
--host 0.0.0.0 --port 8080 \
--controller-address http://localhost:21000
LOGDIR="" python3 -m fastchat.serve.controller \
--host 0.0.0.0 --port 21000
LOGDIR="" RAY_LOG_TO_STDERR=1 \
python3 -m fastchat.serve.vllm_worker \
--model-path ./VirtualCompiler \
--num-gpus 8 \
--controller http://localhost:21000 \
--max-num-batched-tokens 40960 \
--disable-log-requests \
--host 0.0.0.0 --port 22000 \
--worker-address http://localhost:22000 \
--model-names "VirtualCompiler"
模型部署完成后,使用do_request.py
向模型发起请求:
~/C/VirtualCompiler (main)> python3 do_request.py
test rdx, rdx
setz al
movzx eax, al
neg eax
retn
汇编代码搜索编码器
由于HuggingFace不支持在文件夹内加载远程模型,我们将基于虚拟编译器增强的汇编代码搜索数据集训练的模型托管于vic-encoder。可通过model.py
测试自定义模型加载。
以下是文本编码器与汇编编码器的使用示例。提取二进制文件中汇编代码的方法请参考process_asm.py脚本。
def calc_map_at_k(logits, pos_cnt, ks=[10,]):
_, indices = torch.sort(logits, dim=1, descending=True)
ranks = torch.nonzero(
indices < pos_cnt,
as_tuple=False
)[:, 1].reshape(logits.shape[0], -1)
mrr = torch.mean(1 / (ranks + 1), dim=1)
res = {}
for k in ks:
res[k] = (
torch.sum((ranks < k).float(), dim=1) / min(k, pos_cnt)
).cpu().numpy()
return ranks.cpu().numpy(), res, mrr.cpu().numpy()
pos_asm_cnt = 1
query = ["列出目录中的所有文件"]
anchor_asm = [ {"1": "endbr64", "2": "mov eax, 0" }, ... ]
neg_anchor_asm = [ {"1": "push rbp", "2": "mov rbp, rsp", ... }, ... ]
query_embs = text_encoder(**text_tokenizer(query))
kwargs = dict(padding=True, pad_to_multiple_of=8, return_tensors="pt")
anchor_asm_ids = asm_tokenizer.pad([asm_tokenizer(pos) for pos in anchor_asm], **kwargs)
neg_anchor_asm_ids = asm_tokenizer.pad([asm_tokenizer(neg) for neg in neg_anchor_asm], **kwargs)
asm_embs = asm_encoder(**anchor_asm_ids)
asm_neg_emb = asm_encoder(**neg_anchor_asm_ids)
logits_pos = torch.einsum(
"ic,jc->ij", [query_embs, asm_embs])
logits_neg = torch.einsum(
"ic,jc->ij", [query_embs, asm_neg_emb[pos_asm_cnt:]]
)
logits = torch.cat([logits_pos, logits_neg], dim=1)
ranks, map_at_k, mrr = calc_map_at_k(
logits, pos_asm_cnt, [1, 5, 10, 20, 50, 100])