许可证: mit
语言:
- 英文
库名称: transformers
标签:
- 推特
- 垃圾内容检测
基础模型: FacebookAI/xlm-roberta-large
推理支持: 是
推特垃圾内容检测模型
该模型用于将X平台(原Twitter)的推文分类为「垃圾内容」(1) 或「优质内容」(0)。
训练数据集
模型基于UtkMl推特垃圾检测数据集微调,采用FacebookAI/xlm-roberta-large
作为基础模型。
使用方法
以下示例代码可帮助您从文本推文数据集中检测垃圾内容:
def classify_texts(df, text_col, model_path="cja5553/xlm-roberta-Twitter-spam-classification", batch_size=24):
'''
使用预训练序列分类模型将文本分类为「优质内容」或「垃圾内容」。
参数:
-----------
df : pandas.DataFrame
包含待分类文本的DataFrame
text_col : str
存储待分类文本的列名
model_path : str, 默认="cja5553/xlm-roberta-Twitter-spam-classification"
预训练序列分类模型的路径
batch_size : int, 可选, 默认=24
批量处理数据的大小,根据GPU内存调整
返回:
--------
pandas.DataFrame
原始DataFrame新增spam_prediction列,包含每条文本的预测标签("Quality"或"Spam")
'''
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).to("cuda")
model.eval()
df["text"] = df[text_col].astype(str)
text_dataset = Dataset.from_pandas(df)
def tokenize_function(example):
return tokenizer(
example["text"],
padding="max_length",
truncation=True,
max_length=512
)
text_dataset = text_dataset.map(tokenize_function, batched=True)
text_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
text_loader = DataLoader(text_dataset, batch_size=batch_size)
predictions = []
with torch.no_grad():
for batch in tqdm_notebook(text_loader):
input_ids = batch['input_ids'].to("cuda")
attention_mask = batch['attention_mask'].to("cuda")
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
preds = torch.argmax(logits, dim=-1).cpu().numpy()
predictions.extend(preds)
id2label = {0: "Quality", 1: "Spam"}
predicted_labels = [id2label[pred] for pred in predictions]
df["spam_prediction"] = predicted_labels
return df
spam_df_classification = classify_texts(df, "text_col")
print(spam_df_classification)
性能指标
基于80-10-10的训练-验证-测试集划分,测试集结果如下:
- 准确率: 0.974555
- 精确率: 0.97457
- 召回率: 0.97455
- F1分数: 0.97455
源代码
训练代码已开源:github.com/cja5553/Twitter_spam_detection
问题咨询?
请联系邮箱:alba@wustl.edu