---
license: afl-3.0
language:
- en
metrics:
- accuracy
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
tags:
- tarot
- question-detector
---
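DistilBERT Divination Question Detector
This project provides a DistilBERT-based divination question detector that determines whether an input text is a question suitable for a tarot reading.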
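📂 Directory structure
- `model.safetensors`: trained model weights
- `config.json`: model architecture configuration
- `tokenizer.json`: tokenizer configuration
- `special_tokens_map.json`: special token mapping
- `vocab.txt`: tokenizer vocabulary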
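Before loading the model, it can help to confirm that the model directory actually contains the files listed above. The sketch below uses a hypothetical helper, `missing_model_files` (not part of this project), and the path `./distilbert-question-detector` is only an example; it checks file names only and does not validate file contents.

```python
from pathlib import Path
from typing import List, Union

# File names taken from the directory listing above
REQUIRED_FILES = [
    "model.safetensors",
    "config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
]


def missing_model_files(model_dir: Union[str, Path]) -> List[str]:
    """Return the expected files that are absent from model_dir (hypothetical helper)."""
    root = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    # Example path; point this at wherever the model files are stored
    missing = missing_model_files("./distilbert-question-detector")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All model files found.")
```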
🚀 Quick start
1️⃣ Install dependencies
Make sure Python 3.8+ is installed, then install the required libraries:
```bash
pip install torch transformers fastapi uvicorn safetensors
```
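To confirm the installation worked, a quick check can look up each package without importing it. This is a minimal sketch; `missing_packages` is a hypothetical helper, and it assumes each pip package name above matches its import name (true for these five).

```python
import importlib.util
import sys
from typing import List

# Import names of the packages installed by the pip command above
REQUIRED_PACKAGES = ["torch", "transformers", "fastapi", "uvicorn", "safetensors"]


def missing_packages(packages: List[str]) -> List[str]:
    """Return the packages that cannot be found in the current environment."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    if sys.version_info < (3, 8):
        sys.exit("Python 3.8+ is required")
    missing = missing_packages(REQUIRED_PACKAGES)
    print("Missing packages:", missing if missing else "none")
```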
2️⃣ Run inference directly
To test the model locally, run `inference.py`:
```bash
python inference.py
```
Example code (`inference.py`):
```python
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Directory containing model.safetensors, config.json and the tokenizer files
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()  # switch to evaluation mode (disables dropout)

text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)

# Forward pass without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")
```
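For this two-class model, softmax turns the two logits into probabilities that sum to 1, and argmax picks the index of the larger one. The plain-Python sketch below (no PyTorch needed) mirrors those two steps from `inference.py` for a single row of logits; the function names are illustrative and the logit values in the usage line are made up, not real model output.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    # Subtracting the max leaves the result unchanged but avoids overflow in exp()
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(logits: List[float]) -> Tuple[List[float], int]:
    """Return (probabilities, predicted_class) for one row of logits."""
    probs = softmax(logits)
    # Like torch.argmax, ties resolve to the first (lowest) index
    predicted_class = max(range(len(probs)), key=lambda i: probs[i])
    return probs, predicted_class


# Illustrative logits only
probs, cls = predict_from_logits([2.0, -0.5])
print(probs, cls)
```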
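3️⃣ Run the API
You can also serve the model over HTTP with FastAPI so that other applications can query it. Save the code below as `app.py`, then start the server:
```bash
uvicorn app:app --host 0.0.0.0 --port 8000
```
Example API code (`app.py`):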
```python
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

app = FastAPI()

# Path to the fine-tuned checkpoint; adjust to where the model files are stored
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()


class PredictRequest(BaseModel):
    # Matches the JSON body {"text": "..."} sent by the curl example
    text: str


@app.post("/predict/")
def predict(request: PredictRequest):
    # A plain `def` lets FastAPI run the blocking model call in a thread pool
    inputs = tokenizer(request.text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    return {"text": request.text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
```
Once the API is running, you can test it with:
```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/predict/' \
  -H 'Content-Type: application/json' \
  -d '{"text": "Is this a valid question?"}'
```
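The same request can be sent from Python using only the standard library. This is a minimal sketch assuming the server runs at `http://127.0.0.1:8000` and accepts the JSON body shown in the curl command; `build_predict_request` and `call_predict` are hypothetical helper names.

```python
import json
import urllib.request
from typing import Any, Dict

API_URL = "http://127.0.0.1:8000/predict/"  # address used in the curl example


def build_predict_request(text: str, url: str = API_URL) -> urllib.request.Request:
    """Build a POST request whose JSON body matches the curl example."""
    body = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_predict(text: str, url: str = API_URL, timeout: float = 10.0) -> Dict[str, Any]:
    """Send the request and decode the JSON response (the server must be running)."""
    with urllib.request.urlopen(build_predict_request(text, url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, `call_predict("Is this a valid question?")` should return a dict with the `text`, `probabilities` and `predicted_class` fields described in the results section.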
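📌 Interpreting the results
- `predicted_class: 0`: the input text meets the criteria (it is a question suitable for a tarot reading)
- `predicted_class: 1`: the input text does not meet the criteria
Example result: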
```json
{
  "text": "Is this a valid question?",
  "probabilities": [[0.9266, 0.0734]],
  "predicted_class": 0
}
```
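A client usually wants a yes/no answer rather than a raw class index. The sketch below turns a response shaped like the example result into a boolean using the mapping stated in this section (0 = meets the criteria, 1 = does not). `is_valid_divination_question` and `confidence` are hypothetical helpers, and `confidence` reads only the first row of `probabilities` because the API scores one text per request.

```python
from typing import Any, Dict

# Class mapping from the results section: 0 = meets the criteria, 1 = does not
VALID_CLASS = 0


def is_valid_divination_question(result: Dict[str, Any]) -> bool:
    """Return True if the API result says the text meets the criteria."""
    return result["predicted_class"] == VALID_CLASS


def confidence(result: Dict[str, Any]) -> float:
    """Probability assigned to the predicted class (first and only row)."""
    return result["probabilities"][0][result["predicted_class"]]


# The example result from this section
example = {
    "text": "Is this a valid question?",
    "probabilities": [[0.9266, 0.0734]],
    "predicted_class": 0,
}
print(is_valid_divination_question(example), confidence(example))
```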