库名称: tf-keras
许可证: mit
任务标签: 图像到图像
MSI-网络
📖 用于视觉显著性预测的上下文编码器-解码器网络
🤗 该模型的演示可在HuggingFace Spaces上找到。
概述
MSI-Net是一种视觉显著性模型,通过基于眼动数据训练的上下文编码器-解码器网络预测人类在自然图像上的注视点。该模型基于卷积神经网络架构,包含一个具有不同膨胀率的多个卷积层的ASPP模块,以并行捕获多尺度特征。此外,它将生成的表示与全局场景信息相结合,以实现对视觉显著性的准确预测。MSI-Net包含约2500万个参数,因此是计算资源有限应用的合适选择。有关该模型的更多信息,请查看GitHub及相应的论文或预印本。
要求
要安装所需的依赖项,请使用pip
或conda
:
pip install -r requirements.txt
conda env create -f requirements.yml
使用示例
导入依赖项
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
下载仓库文件
hf_dir = snapshot_download(repo_id="alexanderkroner/MSI-Net")
加载显著性模型
model = tf.keras.models.load_model(hf_dir)
加载输入预处理和输出后处理的函数
def get_target_shape(original_shape):
original_aspect_ratio = original_shape[0] / original_shape[1]
square_mode = abs(original_aspect_ratio - 1.0)
landscape_mode = abs(original_aspect_ratio - 240 / 320)
portrait_mode = abs(original_aspect_ratio - 320 / 240)
best_mode = min(square_mode, landscape_mode, portrait_mode)
if best_mode == square_mode:
target_shape = (320, 320)
elif best_mode == landscape_mode:
target_shape = (240, 320)
else:
target_shape = (320, 240)
return target_shape
def preprocess_input(input_image, target_shape):
input_tensor = tf.expand_dims(input_image, axis=0)
input_tensor = tf.image.resize(
input_tensor, target_shape, preserve_aspect_ratio=True
)
vertical_padding = target_shape[0] - input_tensor.shape[1]
horizontal_padding = target_shape[1] - input_tensor.shape[2]
vertical_padding_1 = vertical_padding // 2
vertical_padding_2 = vertical_padding - vertical_padding_1
horizontal_padding_1 = horizontal_padding // 2
horizontal_padding_2 = horizontal_padding - horizontal_padding_1
input_tensor = tf.pad(
input_tensor,
[
[0, 0],
[vertical_padding_1, vertical_padding_2],
[horizontal_padding_1, horizontal_padding_2],
[0, 0],
],
)
return (
input_tensor,
[vertical_padding_1, vertical_padding_2],
[horizontal_padding_1, horizontal_padding_2],
)
def postprocess_output(
output_tensor, vertical_padding, horizontal_padding, original_shape
):
output_tensor = output_tensor[
:,
vertical_padding[0] : output_tensor.shape[1] - vertical_padding[1],
horizontal_padding[0] : output_tensor.shape[2] - horizontal_padding[1],
:,
]
output_tensor = tf.image.resize(output_tensor, original_shape)
output_array = output_tensor.numpy().squeeze()
output_array = plt.cm.inferno(output_array)[..., :3]
return output_array
加载并预处理示例图像
input_image = tf.keras.utils.load_img(hf_dir + "/example.jpg")
input_image = np.array(input_image, dtype=np.float32)
original_shape = input_image.shape[:2]
target_shape = get_target_shape(original_shape)
input_tensor, vertical_padding, horizontal_padding = preprocess_input(
input_image, target_shape
)
将输入张量输入模型
output_tensor = model(input_tensor)["output"]
后处理并可视化输出
saliency_map = postprocess_output(
output_tensor, vertical_padding, horizontal_padding, original_shape
)
alpha = 0.65
blended_image = alpha * saliency_map + (1 - alpha) * input_image / 255
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(input_image / 255)
plt.title("输入图像")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(blended_image)
plt.title("显著性图")
plt.axis("off")
plt.tight_layout()
plt.show()
数据集
在基于注视数据训练模型之前,编码器权重是从在ImageNet分类任务上预训练的VGG16主干初始化的。然后,该模型在SALICON数据集上进行了训练,该数据集包含作为注视测量代理的鼠标移动记录。最后,权重可以在人类眼动追踪数据上进行微调。因此,MSI-Net也在以下数据集之一上进行了训练,尽管这里我们仅提供SALICON基础模型:
我们模型的评估可在原始MIT显著性基准和更新的MIT/Tübingen显著性基准上找到。后者的结果来自预测显著性图的概率表示,并进行了特定度量的后处理,以实现公平的模型比较。
局限性
MSI-Net是在自由观看范式下收集的人类注视数据上训练的。因此,预测的显著性图可能无法推广到在实验过程中接受任务指令的观察者。还必须注意,训练数据主要由自然图像组成。因此,对于特定图像类型(例如分形、图案)或对抗性示例的注视预测可能不太准确。
另一个限制是,基于显著性的裁剪算法(2018年至2021年间应用于上传到社交媒体平台Twitter的图像)在种族和性别方面显示出偏见。因此,使用显著性模型时需要谨慎,并承认其应用中涉及的潜在风险。
引用
如果您发现此代码或模型有用,请引用以下论文:
@article{kroner2020contextual,
title={Contextual encoder-decoder network for visual saliency prediction},
author={Kroner, Alexander and Senden, Mario and Driessens, Kurt and Goebel, Rainer},
url={http://www.sciencedirect.com/science/article/pii/S0893608020301660},
doi={https://doi.org/10.1016/j.neunet.2020.05.004},
journal={Neural Networks},
publisher={Elsevier},
year={2020},
volume={129},
pages={261--270},
issn={0893-6080}
}