license: apache-2.0
language:
- en
- zh
pipeline_tag: token-classification
bert-chunker-2
GitHub
bert-chunker-2 是一款基于 BERT 的文本分块器,通过分类器头部预测分块的起始标记(适用于 RAG 等场景),并采用滑动窗口技术将任意长度的文档切割成文本块。我们将其视为 semantic chunker 的替代方案,但特别之处在于,它不仅适用于结构化文本,还能处理非结构化和杂乱文本。作为 bert-chunker 的新实验版本,它针对文章结构进行了优化,力求在语义分块与结构分块之间取得平衡。该模型是通过对训练好的语义分块器和结构分块器进行 0.1:0.9 线性权重融合而成。
更成熟的新版本是 bert-chunker-3。
快速开始
运行以下代码:
import torch
from transformers import AutoConfig,AutoTokenizer,BertForTokenClassification
import math
model_path="tim1900/bert-chunker-2"
tokenizer = AutoTokenizer.from_pretrained(
model_path,
padding_side="right",
model_max_length=255,
trust_remote_code=True,
)
config = AutoConfig.from_pretrained(
model_path,
trust_remote_code=True,
)
device = 'cpu'
model = BertForTokenClassification.from_pretrained(model_path, ).to(device)
def chunk_text(model,text:str, tokenizer, prob_threshold=0.5)->list[str]:
MAX_TOKENS=255
tokens=tokenizer(text, return_tensors="pt",truncation=False)
input_ids=tokens['input_ids']
attention_mask=tokens['attention_mask'][:,0:MAX_TOKENS]
attention_mask=attention_mask.to(model.device)
CLS=input_ids[:,0].unsqueeze(0)
SEP=input_ids[:,-1].unsqueeze(0)
input_ids=input_ids[:,1:-1]
model.eval()
split_str_poses=[]
token_pos = []
windows_start =0
windows_end= 0
logits_threshold = math.log(1/prob_threshold-1)
print(f'Processing {input_ids.shape[1]} tokens...')
while windows_end <= input_ids.shape[1]:
windows_end= windows_start + MAX_TOKENS-2
ids=torch.cat((CLS, input_ids[:,windows_start:windows_end],SEP),1)
ids=ids.to(model.device)
output=model(input_ids=ids,attention_mask=torch.ones(1, ids.shape[1],device=model.device))
logits = output['logits'][:, 1:-1,:]
chunk_decision = (logits[:,:,1]>(logits[:,:,0]-logits_threshold))
greater_rows_indices = torch.where(chunk_decision)[1].tolist()
if len(greater_rows_indices)>0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices)==1)):
split_str_pos=[tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices]
token_pos +=[sp + windows_start + 1 for sp in greater_rows_indices]
split_str_poses += split_str_pos
windows_start = greater_rows_indices[-1] + windows_start
else:
windows_start = windows_end
substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])]
token_pos = [0] + token_pos
return substrings,token_pos
text='''在繁华都市的中心,摩天大楼高耸入云,汽车喇叭声此起彼伏,永不停歇。
萨拉,一位怀揣小说家梦想的女孩,在这座古老图书馆的静谧角落里找到了慰藉。
她被书架上低语着数个世纪故事的书籍包围,用文字构筑自己的世界,全然忘却外界的喧嚣。
亚历山大·汤普森博士乘坐"潘多拉探险号"飞船,正前往新发现的系外行星Zephyr-7。
作为这次远征的首席天体生物学家,他的任务是在行星的地下冰洞中寻找微生物生命的迹象。
每跨越一光年,揭开可能改变人类对宇宙生命认知的秘密的期待就愈发强烈。'''
chunks, token_pos=chunk_text(model,text, tokenizer, prob_threshold=0.5)
for i, (c,t) in enumerate(zip(chunks,token_pos)):
print(f'-----分块: {i}----标记索引: {t}--------')
print(c)