license: apache-2.0
DiffCSE:基于差异的对比学习框架用于句子嵌入


arXiv论文链接:https://arxiv.org/abs/2204.10298
即将发表于NAACL 2022
作者:
Yung-Sung Chuang,
Rumen Dangovski,
Hongyin Luo,
Yang Zhang,
Shiyu Chang,
Marin Soljačić,
Shang-Wen Li,
Scott Wen-tau Yih,
Yoon Kim,
James Glass
我们的代码主要基于SimCSE的代码实现,更多细节请参考其代码库。
概述

我们提出了DiffCSE,一种无监督对比学习框架用于学习句子嵌入表示。DiffCSE通过捕捉原始句子与编辑后句子的差异来学习嵌入表示,其中编辑后的句子是通过随机掩码原始句子后利用掩码语言模型采样生成。我们证明DiffCSE是等变对比学习(Dangovski et al., 2021)的一个实例,该框架泛化了对比学习,能够学习对某些类型增强不敏感而对其他"有害"增强类型敏感的表示。实验表明,DiffCSE在无监督句子表示学习方法中取得了最先进的结果,在语义文本相似度任务上比无监督SimCSE高出2.3个绝对百分点。
环境配置

系统要求
安装定制化Transformers包
cd transformers-4.2.1
pip install .
若已通过pip安装transformers==4.2.1
,需将modeling_bert.py
放入<your_python_env>/site-packages/transformers/models/bert/modeling_bert.py
,modeling_roberta.py
放入<your_python_env>/site-packages/transformers/models/bert/modeling_roberta.py
。
我们修改了这两个文件以实现基于BERT/RoBERTa的_条件式_预训练任务。建议直接安装我们定制的Transformers包。
安装其他依赖
pip install -r requirements.txt
下载预训练数据集
cd data
bash download_wiki.sh
下载下游任务数据集
cd SentEval/data/downstream/
bash download_dataset.sh
训练
(同run_diffcse.sh
)
python train.py \
--model_name_or_path bert-base-uncased \
--generator_name distilbert-base-uncased \
--train_file data/wiki1m_for_simcse.txt \
--output_dir <模型输出目录> \
--num_train_epochs 2 \
--per_device_train_batch_size 64 \
--learning_rate 7e-6 \
--max_seq_length 32 \
--evaluation_strategy steps \
--metric_for_best_model stsb_spearman \
--load_best_model_at_end \
--eval_steps 125 \
--pooler_type cls \
--mlp_only_train \
--overwrite_output_dir \
--logging_first_step \
--logging_dir <日志目录> \
--temp 0.05 \
--do_train \
--do_eval \
--batchnorm \
--lambda_weight 0.005 \
--fp16 --masking_ratio 0.30
新增参数说明:
--lambda_weight
: 论文第3节所述的λ系数
--masking_ratio
: MLM生成器的掩码比例
--generator_name
: 生成器模型名称。bert-base-uncased
对应distilbert-base-uncased
,roberta-base
对应distilroberta-base
继承自SimCSE的参数:
--train_file
: 训练文件路径
--model_name_or_path
: 预训练模型名称
--temp
: 对比损失温度系数(固定为0.05)
--pooler_type
: 池化方法
--mlp_only_train
: 训练时使用MLP层但测试时不使用,适用于无监督SimCSE/DiffCSE
论文结果基于NVIDIA 2080Ti GPU(CUDA 11.2)获得,不同设备或CUDA/Python/PyTorch版本可能导致性能差异。
评估

我们提供了Colab笔记本便于复现结果,也可通过以下命令评估:
BERT模型
语义相似度任务
python evaluation.py \
--model_name_or_path voidism/diffcse-bert-base-uncased-sts \
--pooler cls_before_pooler \
--task_set sts \
--mode test
迁移学习任务
python evaluation.py \
--model_name_or_path voidism/diffcse-bert-base-uncased-trans \
--pooler cls_before_pooler \
--task_set transfer \
--mode test
RoBERTa模型
语义相似度任务
python evaluation.py \
--model_name_or_path voidism/diffcse-roberta-base-sts \
--pooler cls_before_pooler \
--task_set sts \
--mode test
迁移学习任务
python evaluation.py \
--model_name_or_path voidism/diffcse-roberta-base-trans \
--pooler cls_before_pooler \
--task_set transfer \
--mode test
更多细节请参考SimCSE代码库。
预训练模型

- DiffCSE-BERT-base (STS任务): https://huggingface.co/voidism/diffcse-bert-base-uncased-sts
- DiffCSE-BERT-base (迁移任务): https://huggingface.co/voidism/diffcse-bert-base-uncased-trans
- DiffCSE-RoBERTa-base (STS任务): https://huggingface.co/voidism/diffcse-roberta-base-sts
- DiffCSE-RoBERTa-base (迁移任务): https://huggingface.co/voidism/diffcse-roberta-base-trans
可通过SimCSE提供的API加载模型:
from diffcse import DiffCSE
model_bert_sts = DiffCSE("voidism/diffcse-bert-base-uncased-sts")
model_bert_trans = DiffCSE("voidism/diffcse-bert-base-uncased-trans")
model_roberta_sts = DiffCSE("voidism/diffcse-roberta-base-sts")
model_roberta_trans = DiffCSE("voidism/diffcse-roberta-base-trans")
引用

若我们的工作对您有帮助,请引用我们的论文及SimCSE论文:
@inproceedings{chuang2022diffcse,
title={{DiffCSE}: Difference-based Contrastive Learning for Sentence Embeddings},
author={Chuang, Yung-Sung and Dangovski, Rumen and Luo, Hongyin and Zhang, Yang and Chang, Shiyu and Soljacic, Marin and Li, Shang-Wen and Yih, Wen-tau and Kim, Yoon and Glass, James},
booktitle={Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL)},
year={2022}
}
@inproceedings{gao2021simcse,
title={{SimCSE}: Simple Contrastive Learning of Sentence Embeddings},
author={Gao, Tianyu and Yao, Xingcheng and Chen, Danqi},
booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
year={2021}
}