【论文阅读】GAN模型的搭建以及运行

我学习GAN是因为科研项目的需要,当时想到在项目上使用GAN,是一个因缘巧合吧。2018视觉与学习青年学者研讨会在大连进行,然后蹭了几场talk,和一帮大佬谈笑风生,感觉自己吊吊的。当然这不是重点,重点是当时听了一场依图科技有关核磁共振图像的talk,当时下面有人提到了,样本不均衡的问题。当时依图科技的吴双博士给出了GAN的解决方案。因此迁移到我的科研项目上,我思考了是不是我的项目也能这么用呢。从此开始了GAN的掉坑爬坑之旅。

最近在工作站上安装了pycharm,anaconda和pytorch。算是把基本准备工作完成了,然后需要开始做的是将github上一些BEGAN模型的程序跑起来。然后根据手头的图片改一些代码参数。本文的开头呢,还是一如往常,不是直接开始讲模型,而是先讲一些小的知识点。因为我也是白手起家开始学深度学习相关的知识,并且需要利用深度学习解决项目上的一些问题。因此在很多时候会遇到不少开始以为很难,但是解决之后感觉很简单的问题。其实说白了,这就是基础不够扎实。好了,废话不多说,开始讲干货了。

遇到的一些奇怪问题

pytorch中.pth文件的作用

​ 在github上下载下来的代码中经常会看到.pth文件,我开始也是一脸蒙蔽,后来查了一下。pytorh实现对网络结构和模型参数的保存。有两种方式:一是保存整个神经网络的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict()

1
2
3
4
5
6
## 两种保存模型的方式
torch.save(model_object,'net.pth') # 保存整个神经网络的结构和模型参数
torch.save(net.state_dict(),'net.pth') # 只保存神经网络的模型参数
## 两种加载模型的方式
torch.load('.pth') # 对应第一种保存完整的网络结构信息,重载的时候直接初始化新的神经网络对象
net.load_state_dict(torch.load('.pth')) #对于第二种,需要先导入对应的网络

Python直接输入到文件

print输出直接到文件里 ,在pyton2和3中不一样。

1
2
3
4
5
6
7
8
9
10
11
# python3 
k=10
f=open("./output/recard","w+")
for i in range(k):
print("第{0}条数据".format(i),file=f)

# python2
k=10
f=open("./output/recard","w+")
for i in range(k):
print>>f,"第{0}条数据".format(i)

pycharm进行python调试

想谈这个问题是因为,最近遇到程序总是报错,如果没有一个科学的python调试方案,整个项目进度会非常缓慢。在讲python的debug之前,我们需要先回顾一下visual studioC++ 的调试。

  • 我们先设置断点,这样程序就将运行到这个位置停下来
  • 然后按F5,程序将处于调试状态进行运行
  • 然后按F11程序将逐语句进行运行,这个时候你将看到每个语句运行之后,各个变量的值
  • 当你遇到一个函数的时候,想看程序在函数内部运行的时候,各个变量都是怎么变化的,你可以按F10,程序将进入函数进行运行,然后你再按F11程序将在函数内部逐语句进行运行。

类似地,我们来看在pycharm上python如何进行调试。

  • 设置断点,这和vs上一样。
  • 然后按绿色甲壳虫的符号(或者使用shift+ F9)进行断点调试。
  • 然后点击Step Over(或者F8),继续往下运行,到下一个断点

TensorFlow的data_format:NHWC、NCHW的区别与转换

区别:

  • NHWC: [batch,in_height,in_width,in_channels]
  • NCHW: [batch,in_channels,in_height,in_width]

转换:

  • NHWC->NCHW

    1
    2
    3
    4
    5
    6
    7
    >import tensorflow as tf
    >x=tf.reshape(tf.range(24),[1,3,4,2])
    >out= tf.transpose(x,[0,3,1,2])
    >
    >print x.shape
    >print out.shape
    >

    输出:

    1
    2
    3
    >(1,3,4,2)
    >(1,2,3,4)
    >
  • NCHW->NHWC

    1
    2
    3
    4
    5
    6
    7
    >import tensorflow as tf
    >x=tf.reshape(tf.range(24),[1,2,3,4])
    >out = tf.transpose(x,[0,2,3,1])
    >
    >print x.shape
    >print out.shape
    >

    输出:

    1
    2
    3
    >(1,2,3,4)
    >(1,3,4,2)
    >

Utils文件的作用

经常在程序中遇到Utils文件,开始对这个比较疑惑,后来查了一下它的作用,Utils包中放一些常用的公共方法,提供一些实用的方法和数据结构比如:

  • 日期类来产生和获取日期和时间
  • 提供随机数类来产生各种类型的随机数
  • 提供堆栈类表示堆栈结构
  • 提供哈希表来表示哈希结构

附带着在廖雪峰的教学网站上学到了,有关模块构建的相关知识,也非常实用。

pytorch安装的注意点

哎,今天傻逼了,之前的BEGAN一直没调通,后来想想算了,先用之前跑通的WGAN吧,然后拿过来放在工作站上跑,md,结果怎么都跑不通,总提示torch.FloatTensor object has no attribute shape就是说没有shape这个参数。我就纳了闷了,在pycharm上也不报错,跑到这段程序就报错,一脸蒙蔽。怎么都想不通。晚上过来在命令行里面敲进import pytorch结果提示没有这个模块,我滴个乖乖,原来我压根没安装pytorch这个框架。得,装吧。

安装pytorch,还是比较方便的。就是先到pytorch的官网上,按照自己电脑的环境选择合适的安装命令。

python argparse命令行解析包用法总结

首先解释一下为什么要先讲这个,因为很多在github上的深度学习项目的运行都是采用命令行的形式进行的,一开始我很懵逼,我习惯了用一个run的按钮启动程序。在网上看到python采用命令行形式启动程序的方式,在main函数里面需要加入一些argparse命令行相关参数,因此学习一下这个命令行解析包非常必要。

使用说明:

  • 导入命令行解析器的包

  • 设置分析器

    为脚本添加说明描述:

    1
    2
    3
    4
    5
    > parse = argparse.ArgumentParser(description='描述信息')
    > 或者采用
    > parse =argparse.ArgumentParser()
    > parse.description='描述信息'
    >
  • 设置程序版本号

    程序可能需要版本升级

    1
    2
    > parse.add_argument("-v","-version",action='version',version='%(prog)s 1.0')
    >

    设置参数action 的值为version 就可以了。后面的 %(prog)s代表脚本的名称。是可变的。

  • 必要的参数

    必须要添加的参数:SourceDir是一个必填的参数,后面help是帮助信息,默认参数格式是字符串,如果需要参数是其他内容,需要指定type ,比如你需要的参数数字,就需要用type= int来指定类型;

    1
    2
    > parser.add_argument('sourceDir', help='Select source directory')
    >

    添加一个不带-的前缀的选填参数 :在最后面添加上nargs='?' 表示这一项选填,其中问号表示这项可以是零或者1;

    1
    2
    > parser.add_argument('targetDir', help='Select target directory', nargs='?')
    >

    添加一个带-前缀的选填参数:在参数处加上- 前缀,就表示这个参数是选填的,并且同时表示可以简写或者双横线的全拼;

    1
    2
    > parser.add_argument("-a", "--add", help="add something")
    >

    添加一个互斥的二选一选项参数:需要定义一个参数组group然后往这个组里面添加不同功能的参数就行,action="store_true"这个配置项代表这个不需要填写值,直接写参数就行,输出的时候为布尔值。

    1
    2
    3
    4
    > group = parser.add_mutually_exclusive_group()
    > group.add_argument("-m", "--move", help="The way to operate the file is to move", action="store_true")
    > group.add_argument("-c", "--copy", help="The way to operate the file is to copy", action="store_true")
    >
  • 获取所有参数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    > args = parser.parse_args()
    > print(args)
    >
    > # 打印 usage
    > parser.print_usage()
    > # 打印完整的 help 信息
    > parser.print_help()
    > # 输出 usage
    > parser.format_usage()
    > # 输出 help
    > parser.format_help()
    >

Linux路径总结

  1. Linux中:

    • . 表示当前目录
    • .. 表示上一级目录
    • - 表示上次所在目录
    • ~ 表示当前用户的home目录
    • 使用pwd可以获取当前所在路径(绝对路径)

    • 绝对路径就是以根/目录为起点,以你所到达的目录为终点,表现形式为/usr/local/bin

    • 进入哪个目录取决于哪个更方便,比如当前在/usr/local/bin下,要进入上一级目录当然可以使用cd ..,这就比使用cd /usr/local方便。

GAN的代码讲解及修改

接下来我将通过几个GAN的代码来讲解GAN的工作原理。顺带讲讲其中存在的坑。

参考文献