空洞卷积的应用

Posted by KevinLT on January 6, 2019

空洞卷积介绍

空洞卷积(Dilated Convolution),又名扩张卷积, 顾名思义,就是在普通的卷积操作中加入了空洞。 如下面两张图所示,上图为我们平时常用的普通的3x3卷积核, 而下图是dilation rate为2的3x3卷积核。 dilation rate定义了在空洞卷积核当中每个卷积操作间的距离, 普通的卷积即dilation rate为1的情况。

图1. 普通卷积,dilation rate=1
图1. 普通卷积,dilation rate=1
图2. 空洞卷积,dilation rate=2
图2. 空洞卷积,dilation rate=2

语义分割

我们为什么要使用空洞卷积,它相比普通卷积有什么好处? 这还得从Multi-scale context aggregation by dilated convolutions1这篇文章讲起。 它所研究的问题是语义分割,如下图所示,即对于输入图片,对其进行像素级的分类。

图3. 语义分割示例图
图3. 语义分割示例图

在该领域中最早利用深度卷积神经网络的是Fully convolutional networks for semantic segmentation2这篇文章, 它提出的FCN(Fully Convolutional Networks)模型将分类CNN网络中的全连接层改为卷积层,最后通过上采样获取语义分割结果。

图4. FCN模型
图4. FCN模型

但分类CNN网络用于语义分割任务上是有其限制的。 在分类CNN网络中,池化是很重要的一步操作, 它能用于扩大感受野(Receptive Field),帮助网络提取图像的全局特征。 但在语义分割任务中,其输出的是一张与输入图像分辨率大小相同的分割结果图, 池化会导致原图中的很多信息丢失。

于是Multi-scale context aggregation by dilated convolutions就提出使用空洞卷积来提取图像特征, 抛弃之前常用的池化层,通过提高卷积的dilation rate来扩大感受野。 其优点是既能做到特征图像大小始终保持与输入图像一致,又能够扩大感受野获取到多尺度下的特征信息。 下图就是该文所提出的Context Networks的结构。 可以看到,对于感受野的影响来说,将dilation rate翻倍和使用池化的效果是一致的。

图5. Context Networks结构
图5. Context Networks结构

下图就是Context Networks与FCN的结果对比。 其中第一列为输入图像,第二列为FCN的分割结果,第三列为Context Networks的分割结果,第四列为Ground-Truth。 可以看到FCN的分割粒度明显比Context Networks粗糙,如上图中远方的汽车,下图中椅子之间的间隔。 这一结果表明了空洞卷积相比于池化的好处。

图6. 语义分割结果对比
图6. 语义分割结果对比

机器翻译

空洞卷积在语义分割任务上的成功也启发了领域的研究者们。 Neural machine translation in linear time3这篇文章所提出的ByteNet就是空洞卷积在机器翻译任务上成功应用。 机器翻译即对于给定字符串,利用计算机程序,将其从源语言翻译为目标语言。

早期使用深度学习方法的机器翻译网络都是基于LSTM来搭建的。 如下图所示,蓝色部分为LSTM encoder,用于对输入字符串提取特征。 它提取到的特征即最后一个block所输出的hidden state。 红色部分则为LSTM decoder,它以encoder输出的hidden state作为初始state, 输出翻译字符串。

图7. LSTM encoder-decoder模型
图7. LSTM encoder-decoder模型

根据Neural machine translation in linear time这篇文章的论述,LSTM encoder decoder有以下缺陷。

  1. encoder提取的特征大小为常量,即最后一个hidden state的size,但输出序列的长度是变化。
  2. 难以解决长距离依赖问题。理想的网络结构下,在前向和反向传播中, 输入token和输出token间的距离应尽可能与序列长度解耦。

于是,它们选择用空洞卷积神经网络来构建encoder-decoder网络,下图即他们所提出的ByteNet网络。 图中底部s即输入字符串中的字符的embedding表示,网络通过一维空洞卷积来对其提取特征。 得益于感受野的快速扩张,网络即能获取多尺度下的序列信息。 与LSTM encoder相比,CNN encoder输出的特征张量大小与输出序列长度是成正比的。 decoder网络与encoder网络类似,不同点在与其采用了causal连接, 即current token仅利用其之前的toekn信息,future token不参与计算。 可以看到,输出token与各输入token之间距离是相等的,这解决了之前LSTM的第2个缺陷。

图8. ByteNet
图8. ByteNet

相比于之前的LSTM网络,ByteNet取得了当时在机器任务翻译上的SOTA结果。

人体三维姿态估计

人体姿态估计任务即给定图像,预测出其中人物的关节点位置。 二维姿态估计预测的是关节点在图像平面上的坐标, 三维姿态估计预测的是关节点在三维坐标系下的位置。 在三维姿态估计方法中,目前较为成功的一类是两阶段的方法, 即先估计出二维关节点位置,再以二维关节点位置为输入预测三维姿态。

Exploiting temporal information for 3D human pose estimation4这篇文章就是上述的两阶段方法的一个典型。 它借鉴了机器翻译领域LSTM encoder decoder网络,将从视频中各帧提取出来的二维姿态序列作为输入,输出三维姿态序列,其结构如下图所示。

图9. 2D pose sequence to 3D pose sequence
图9. 2D pose sequence to 3D pose sequence

既然在机器翻译任务上,ByteNet用空洞卷积打败了LSTM,那么在三维姿态估计任务上其实也不例外。 3D human pose estimation in video with temporal convolutions and semi-supervised training5这篇文章 就提出了与ByteNet异曲同工的Temporal Convolutional Model,本质上还是就是利用了空洞卷积。下图为他们的网络结构示意。 它的不同点在于没有使用decoder部分,直接以二维姿态序列为输入,输出单帧三维姿态。

图10. Temporal Convolutional Model
图10. Temporal Convolutional Model

在Human 3.6m数据集上,Temporal Convolutional Model一举超越了之前的所有结果,取得了目前的SOTA结果。

语音合成

WaveNet: A generative model for raw audio6这篇文章所提出的WaveNet也是利用了空洞卷积来搭建的, 这篇文章我没有读,就不阐述了。

图11. WaveNet
图11. WaveNet

总结

从语义分割再到其他任务上的应用,伴随空洞卷积的关键字就是感受野。 它打破了以往CNN只能靠池化来扩大感受野的局限,使得CNN能够更广泛的应用到其他任务上, 具备了更加通用的多尺度特征提取能力。

参考文献


  1. Yu, Fisher, and Vladlen Koltun. “Multi-scale context aggregation by dilated convolutions.” arXiv preprint arXiv:1511.07122 (2015).

  2. Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for semantic segmentation.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

  3. Kalchbrenner, Nal, et al. “Neural machine translation in linear time.” arXiv preprint arXiv:1610.10099 (2016).

  4. Hossain, Mir Rayat Imtiaz, and James J. Little. “Exploiting temporal information for 3D human pose estimation.” European Conference on Computer Vision. Springer, Cham, 2018.

  5. Pavllo, Dario, et al. “3D human pose estimation in video with temporal convolutions and semi-supervised training.” arXiv preprint arXiv:1811.11742 (2018).

  6. Van Den Oord, Aäron, et al. “Wavenet: A generative model for raw audio.” CoRR abs/1609.03499 (2016).