---
library_name: transformers
tags:
- vision
- image-segmentation
- ecology
datasets:
- coralscapes
metrics:
- mean_iou
license: apache-2.0
---
# SegFormer (b5-sized) model fine-tuned on Coralscapes

SegFormer model with a MiT-B5 backbone, fine-tuned on the Coralscapes dataset at 1024x1024 resolution, as introduced in [The Coralscapes Dataset: Semantic Scene Understanding in Coral Reefs](https://arxiv.org/abs/2503.20000).
## Model Details

### Model Description

### Model Sources
## How to Get Started with the Model

The simplest way to use this model to segment an image from the Coralscapes dataset is as follows:
```python
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from datasets import load_dataset

# Load an example image from the Coralscapes test split
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024")

inputs = preprocessor(image, return_tensors="pt")
outputs = model(**inputs)

# Upsample the logits to the original image size; PIL's image.size is (width, height)
outputs = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])
label_pred = outputs[0].numpy()
```
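To sanity-check the output, the predicted label map can be overlaid on the input image. A minimal sketch using matplotlib (not part of the original card; the colormap choice is arbitrary):

```python
import matplotlib.pyplot as plt

# Overlay the predicted class indices on the input image (illustrative only;
# "tab20" has fewer colors than the 40 classes, so some classes share a color).
plt.imshow(image)
plt.imshow(label_pred, alpha=0.5, cmap="tab20", interpolation="nearest")
plt.axis("off")
plt.show()
```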
For images that differ substantially in size from the model's training size of 1024x1024, better results can be obtained with a sliding-window approach:
```python
import torch
import torch.nn.functional as F
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import numpy as np
from datasets import load_dataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def resize_image(image, target_size=1024):
    """Resize a PIL image so that its shorter side measures target_size pixels."""
    w_img, h_img = image.size
    if w_img < h_img:
        new_w, new_h = target_size, int(h_img * (target_size / w_img))
    else:
        new_w, new_h = int(w_img * (target_size / h_img)), target_size
    return image.resize((new_w, new_h))

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40, transform=None):
    """Predict on overlapping sliding windows derived from the image size and
    aspect ratio, averaging the logits where windows overlap."""
    h_crop, w_crop = crop_size

    img = torch.Tensor(np.array(resize_image(image)).transpose(2, 0, 1)).unsqueeze(0)
    batch_size, _, h_img, w_img = img.size()

    if transform:
        img = torch.Tensor(transform(image=img.numpy())["image"]).to(device)

    # Number of windows per axis, chosen so that neighboring windows overlap
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1

    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop

    # Accumulated logits and a per-pixel visit count used for averaging
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1, x1 = h_idx * h_stride, w_idx * w_stride
            y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]

            with torch.no_grad():
                if preprocessor:
                    inputs = preprocessor(crop_img, return_tensors="pt").to(device)
                    outputs = model(**inputs)
                else:
                    outputs = model(pixel_values=crop_img.to(device))

            # Upsample the window logits to the crop size and paste them into place
            resized_logits = F.interpolate(outputs.logits[0].unsqueeze(0),
                                           size=crop_img.shape[-2:], mode="bilinear")
            preds += F.pad(resized_logits,
                           (x1, preds.shape[3] - x2, y1, preds.shape[2] - y2)).cpu()
            count_mat[:, :, y1:y2, x1:x2] += 1

    preds = (preds / count_mat).argmax(dim=1)
    # Resize the label map back to the original image size; image.size is (width, height)
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode='nearest')
    return preds.squeeze().cpu().numpy()

dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024").to(device)

label_pred = segment_image(image, preprocessor, model)
```
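The returned array holds integer class indices. Assuming the checkpoint ships the usual `id2label` mapping in its config (standard for transformers checkpoints exported with label metadata), the indices can be turned into class names:

```python
import numpy as np

# List the classes that appear in the prediction, mapped to their names.
present = np.unique(label_pred)
print({int(i): model.config.id2label[int(i)] for i in present})
```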
## Training and Evaluation

### Data

The model is trained and evaluated on Coralscapes, a general-purpose dense semantic segmentation dataset for coral reefs.
### Training Procedure

Training follows the original SegFormer implementation, with additional data augmentation (a sketch of a comparable pipeline follows the list):

- Batch size 4, 100 epochs
- AdamW optimizer (initial learning rate 6e-5, weight decay 1e-2)
- Polynomial learning-rate scheduler with power 1
- Training-time augmentation: random scaling (1.02-2x), rotation (±15°), and color jitter (contrast/saturation/brightness 0.8-1.2, hue ±0.05)
- Evaluation with a non-overlapping 1024x1024 sliding window
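A minimal sketch of a comparable augmentation pipeline, assuming torchvision; the card only summarizes the transforms, so the exact composition here is illustrative. Note that for segmentation, the geometric transforms (scaling, rotation, cropping) must be applied identically to the mask, which this sketch omits:

```python
import random
import torchvision.transforms as T
import torchvision.transforms.functional as TF

def random_scale_crop(img, crop=1024, lo=1.02, hi=2.0):
    """Scale the image by a random factor in [lo, hi], then random-crop to crop x crop."""
    s = random.uniform(lo, hi)
    img = TF.resize(img, [int(img.height * s), int(img.width * s)])
    return T.RandomCrop(crop)(img)  # assumes the scaled image is at least crop x crop

# Photometric and rotation augmentations as summarized above (illustrative, not the training code)
jitter_and_rotate = T.Compose([
    T.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2),
                  saturation=(0.8, 1.2), hue=0.05),
    T.RandomRotation(degrees=15),
])
```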
### Results

- Test accuracy: 82.761
- Test mean IoU: 57.800
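A hedged sketch of how such numbers can be computed with the `evaluate` library's `mean_iou` metric, reusing `segment_image` from above; the `label` column name for the ground-truth mask and `ignore_index=0` are assumptions, not taken from the card:

```python
import numpy as np
import evaluate

mean_iou = evaluate.load("mean_iou")

sample = dataset["test"][42]
label_pred = segment_image(sample["image"], preprocessor, model)
label_gt = np.array(sample["label"])  # assumed column name for the ground-truth mask

results = mean_iou.compute(predictions=[label_pred], references=[label_gt],
                           num_labels=40, ignore_index=0)  # ignore_index is an assumption
print(results["mean_iou"], results["overall_accuracy"])
```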
## Citation

If you use this model, please cite:
```bibtex
@misc{sauder2025coralscapesdatasetsemanticscene,
      title={The Coralscapes Dataset: Semantic Scene Understanding in Coral Reefs},
      author={Jonathan Sauder and Viktor Domazetoski and Guilhem Banc-Prandi and Gabriela Perna and Anders Meibom and Devis Tuia},
      year={2025},
      eprint={2503.20000},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2503.20000},
}
```