这篇博客讲讲对AlexNet如何调参,最近因为科研的原因,对卷积神经网络进行了调参,忙活了大概一两个星期,期间对AlexNet进行了全方面、多层次、宽领域的调参,总结了一些调参技巧,在这儿总结一下。
调参的技巧,我打算从两个方面来讲,一个是模型,一个是训练参数。
环境参数
- Pytorch 0.4.1
- torchvision 0.1.8
- Python 3.6.3
- CUDA 8.0.61
- Linux + Pycharm
模型参数
首先我摆出原始的模型代码,是AlexNet的代码,这个在网上很常见。
调参技巧总结/1.png)
模型代码:
1 | class AlextNet(nn.Module): |
这个就是卷积神经网络的核心代码,这段代码如何看,在我之前的博客《DCGAN模型讲解及避坑指南》 中已经讲的很详细了。这里不再赘述。
这个模型是由两大部分组成,卷积层和全连接层。可以看到参数也非常多,因此可以调节的部分也非常多。
模型调参技巧-卷积层
首先在《卷积神经网络的卷积核大小、个数,卷积层数如何确定呢?》 讲述了相当全面的卷积核以及卷积层的参数确定。
- 卷积核
卷积核是卷积神经网络的核心,卷积核的大小设置,我都是设置为3×3,一个方面是因为我的图片比较小,还有一个原因是这个卷积核是最小的并且能够体现出上下左右中方位信息的卷积核。
- 池化层
池化层的窗口大小是2×2,可以搭配stride =2 的步幅,这是为了层加深就需要使下采样速度更慢。
- 层数
开始的代码是三层卷积层,但是效果一直都一般,后来没办法,我添加了一个卷积层。卷积层达到了四层,而卷积核有五个。一般来说,卷积核的个数最好是奇数个。
模型调参技巧-全连接层
卷积层是用来把图片的特征提取出来,而全连接层是根据提取出来的特征将图片进行分类。在这篇《Dropout 层应该加在什么地方?》 中讲述了dropout层应该放置的位置。一般都要放置在全连接层中
,我也尝试了放置在卷积层中,但是效果一般。(原因是:卷积层参数比较少,不容易过拟合)。这篇博客还记录了外文论文和博客关于这一观点的例子。
其次是全连接层的dropout
层的参数设置。我的个人经验是,不要设置地太大,我一般设置在0.1~0.3
之间。
我经常看网上有人推荐使用softmax
进行分类,但是实际上在我的代码里面crossEntropyloss()
函数里面就包含了softmax
,在这篇博客《PyTorch学习》里面讲过。
训练参数调参技巧
这个训练参数其实说的不太好,准确来说应该是,图片的参数和模型训练过程中遇到的参数的调整。
还有一个很重要的调参技巧:
- 用小数据集进行训练,在小训练集上取得一个比较好的结果,再到大数据集上调一调,这样的好处是小数据集迭代速度快节约时间,还有就是小数据集上能取得比较好的结果,一般大数据集上能取得更好的结果。
训练参数调参技巧-图片读取
这个图片读取的方法,在我之前的博客《基于Pytorch对WGAN_gp模型进行调参总结》里面也讲过,主要用的函数代码是:
1 | data_transforms = { |
这段代码的意思就是按照data_trandform
的形式进行读取,读取之后按照指定的batch_size
和 shuffle
方式进行载入数据。开始的时候我对这部分代码不以为意,想着就是简单的数据载入,没什么可以改的地方。因为我看网上给出的分类代码,这部分的参数都差不多,所以我想改点参数应该没什么影响,后来发现我错了!!!
特别有效的参数是:transforms.CenterCrop()
这个函数是将更改大小之后的图片进行剪裁,是对图片的中心部分进行剪裁。刚开始我在这设置的参数是96
,训练的准确率一直在85%
左右徘徊,但是将参数改为224
之后,训练准确率达到惊人的92%
,本人感觉十分满意。
训练参数调参技巧-优化器(Optimizer
)
这个优化器也是调参非常关键的部分,因为优化器控制了梯度变化的方式,如果优化器搭配地好,将显著提升训练集的准确率。
优化器可以调节的参数主要包括:优化器的种类, 学习率, weight_decay
等。在这篇博客《当前训练神经网络最快的方式:AdamW优化算法+超级收敛》中 讲了不同的优化器的设置,尤其是L2正则化的参数设置,L2正则化就是weight_decay
参数,这个参数可以防止过拟合。
除此之外还有学习率需要调节,学习率在程序迭代过程中,随着loss
值的下降,还需要继续削减,因此可以设置一个判断语句,只要等到迭代条件达到某一程度时就可以就可以继续减小其值。
当然由于GAN
添加了一些生成的数据到训练集中,因此会导致过拟合(测试集中的数据都是原始数据,模型学习到的是原始数据和生成数据的模样,这样也可以很定性地理解这个问题)也就不奇怪了,我在这篇《读书笔记:对抗过拟合》博客中找到一些对抗过拟合的方法。比如早停止,dropout,数据增强等方法。
展望
我在这篇博客《Pytorch实战2:ResNet-18实现Cifar-10图像分类(测试集分类准确率95.170%)》里面看到,使用ResNet-18
实现图像分类的功能,效果也非常好。如果后期,有需要通过不同的网络进行对比试验,这是一个好的参照点。