caffe实战之classify.py解析

来源:互联网 时间:2017-06-01

本文将对caffe/python下的classify.py代码以及相关的classifier.py和io.py进行解析。

一、classify.py

由最后的if __name__ == '__main__': main(sys.argv)代表该文件在命令行下运行,则运行main函数,参数存放在sys.argv中。在main函数定义中,分别判断并存入各类参数,分别如下:

input_file:输入图像,参数为必需。

output_file:输出文件,参数为必需。

--model_def:网络测试结构文件,默认为imagenet的deploy.txt文件

--pretrained_model:网络参数文件,默认为imagenet的bvlc_reference_caffenet.caffemodel文件。

--gpu:是否用gpu计算,action=’store true’表示如果不指定,则默认false,用cpu,否则为true,用gpu。对于一张128*128的灰度图像,cpu前向计算大概20ms,而gpu仅5ms左右。

--center_only:默认false,即对输入图像的裁剪图像做预测,然后将结果进行平均;指定为true,即只取输入图像的中间部分做一次预测。当然,如果指定输入图像和裁剪尺寸一致,那么取中间部分即为原图本身。

--images_dim:输入图像尺寸,只考虑高和宽,默认256*256。

--mean_file:均值文件。注意数据格式是npy文件,即存储为numpy.array格式,维度为(通道,高,宽)。如果仅有通过compute_mean.bin计算的均值文件,需要进行转化。默认均值文件为imagenet的ilsvrc_2012_mean.npy文件。

--input_scale:图像预处理后的缩放系数,发生在减去均值后,默认为1。

--raw_scale:图像预处理前的缩放系数,发生在减去均值前。由于读入的像素值在[0,1]区间,则默认为255.0,使像素在[0,255]区间。

--channel_swap:通道调整,默认为’2,1,0’,因为caffe通过opencv读入的图片通道为BGR,因此必须将RGB-->BGR,即第0个通道和第2个通道交换。

--ext:默认’jpg’,代表如果输入指定为目录,则仅读取后缀名为jpg的文件。

下面几个参数是改进版classify.py中新加的。

--labels_file:标签类别文件,默认为imagenet的synset_words.txt文件。

--print_results:是否打印结果到屏幕,不指定则false,指定为true。

--force_grayscale:是否指定输入为单通道图像,不指定则false,指定为true。

通过args = parser.parse_args()更新,确认最终输入的参数。下面进行分类测试:

# 列表生成式,通过逗号划分维度字符串,并强制转化为int类型。最后为列表。

image_dims = [int(s) for s in args.images_dim.split(',')]

# 如果指定了均值文件,则加载均值文件

if args.mean_file:

mean = np.load(args.mean_file)

# 如果是灰度图像,则没有通道交换。如果是rgb图像,如果有通道交换,通过逗号划 分字符串,强制转化为int类型,存到列表中。

    if args.force_grayscale:

        channel_swap = None

    else:

        if args.channel_swap

            channel_swap = [int(s) for s in args.channel_swap.split(',')]

# 如果指定了gpu,则启动gpu模式

    if args.gpu:

        caffe.set_mode_gpu()

        print("GPU mode")

    else:

        caffe.set_mode_cpu()

        print("CPU mode")

# 初始化分类器,见classifier.py

classifier = caffe.Classifier(..)

# 下面就是读取文件的代码,有反映说加载灰度图像会报错的情况,这里给出记载灰度和rgb图像的代码。

    if args.force_grayscale:

    # 这里的false代表返回单通道图像,见io.py

        inputs = [caffe.io.load_image(args.input_file, False)]

    else:

        inputs = [caffe.io.load_image(args.input_file)]

# inputs用[]括起来,代表用列表存储,所以len(inputs)代表有多少张输入图像。

 

# 计时,这里以ms为单位

start = time.time() * 1000

# 前向计算,见classifier.py,得到preditions为np数组,行为输入图像张数,列为预测总类别数目

    predictions = classifier.predict(inputs, not args.center_only)

    print("Done in %.2f ms." % (time.time() * 1000 - start))

print("Predictions : %s" % predictions)

 

# 打印结果,根据得分排序,给出分数较高的前五类,类名称由labels_file指定。

# print result, add by caisenchuan

if args.print_results:

...

二、Classifier.py

该文件定义了classifier类,包括了初始化函数__init__和predict函数。

1、 __init__:

首先调用了caffe类的初始化函数,并设定了test模式。

接着调用了transformer类,以cifar-10为例,输入为字典{’data’: (1,3,32,32)}。

然后是set_transpose方法:

# 将维度从(32,32,3)转化为(3,32,32),适用于caffe中的处理

self.transformer.set_transpose(in_, (2, 0, 1))

然后调用transformer类的set方法,设置各种参数,具体见下文io.py中的解析。

最后,关于图像维度的定义:

# 裁剪尺寸根据prototxt定义

self.crop_dims = np.array(self.blobs[in_].data.shape[2:])

# 如果没有定义图片尺寸参数,则等于裁剪的尺寸;否则按定义的来

# 一般来说,如果用了裁剪,则图像尺寸>裁剪尺寸

    if not image_dims:

        image_dims = self.crop_dims

self.image_dims = image_dims

2、 predict:

执行前向计算,预测图像分类的概率。参数为输入以及是否过采样的布尔值。

# 定义inputs_维度(m,h,w,channel)

input_ = np.zeros((len(inputs),

                   self.image_dims[0],

                   self.image_dims[1],

                   inputs[0].shape[2]),

                   dtype=np.float32)

    # 将所有待分类尺寸统一为image_dims尺寸

    for ix, in_ in enumerate(inputs):

        input_[ix] = caffe.io.resize_image(in_, self.image_dims)

# 如果过采样,则每张图像通过裁剪生成10张图像

# 维度将变为(10*m,h,w,channel)

    if oversample:

        # Generate center, corner, and mirrored crops.

        input_ = caffe.io.oversample(input_, self.crop_dims)

# 否则,裁剪中心区域。取图像尺寸的中点,然后分别往上往下取裁剪的尺寸长度。

# 以64*64裁剪32*32为例,(64,64)取中点-->(32,32),扩充到四个坐标-->(32,32,32,32),

# 取裁剪尺寸(32,32,32,32)+(-16,-16,16,16)-->(16,16,48,48)

    else:

        # Take center crop.

        center = np.array(self.image_dims) / 2.0

        crop = np.tile(center, (1, 2))[0] + np.concatenate([

            -self.crop_dims / 2.0,

            self.crop_dims / 2.0

        ])

        crop = crop.astype(int)

        input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]

# 将输入转化为caffe需要的格式,维度变为(m,channel,h,w)

caffe_in = np.zeros(np.array(input_.shape)[[0, 3, 1, 2]],

                    dtype=np.float32)

# 每张图片都进行预处理,见io.py的preprocess函数

    for ix, in_ in enumerate(input_):

        caffe_in[ix] = self.transformer.preprocess(self.inputs[0], in_)

    # 前向计算,输出为字典,out[‘prob’]为各类概率

out = self.forward_all(**{self.inputs[0]: caffe_in})

    predictions = out[self.outputs[0]]

 

    # 如果过采样,需要对每10个预测结果进行平均

    if oversample:

        predictions = predictions.reshape((len(predictions) / 10, 10, -1))

        predictions = predictions.mean(1)

# 返回结果

return predictions

三、io.py

该文件重点介绍预处理类Transformer的成员函数。

1、preprocess

注意到函数的注释部分表明了预处理的全部流程,包括:

转化为单精度;

resize到统一尺寸;

维度转化为(channel,h,w);

    通道交换,转化为BGR;

减去均值前缩放;    

    减去均值;

减去均值后缩放。

重要代码:

...

# 返回[h,w]

in_dims = self.inputs[in_][2:]

# 输入图像和规定尺寸不一样,则resize统一

if caffe_in.shape[:2] != in_dims:

    caffe_in = resize_image(caffe_in, in_dims)

# 维度转化

if transpose is not None:

    caffe_in = caffe_in.transpose(transpose)

# 通道交换,指的是channel的交换,h和w不变

if channel_swap is not None:

    caffe_in = caffe_in[channel_swap, :, :]

# 乘法

if raw_scale is not None:

    caffe_in *= raw_scale

# 减法

if mean is not None:

    caffe_in -= mean

# 乘法

if input_scale is not None:

    caffe_in *= input_scale

return caffe_in

2、Load_image,注意color参数默认True

# 利用skimage工具读入图片,默认读入彩色图片,如果as_grey为1,则读入灰度图片;读入值为[0,1]的浮点数

img = skimage.img_as_float(skimage.io.imread(filename, 

                       as_grey=not color)).astype(np.float32)

# 保证返回的是三维的数组。

    if img.ndim == 2:

    # 如果读入只有二维,需要增加维度

        img = img[:, :, np.newaxis]

        if color:

            # 如果是灰度图像却以彩色图片读入,那么扩充为三个通道

            img = np.tile(img, (1, 1, 3))

     elif img.shape[2] == 4:

         # 如果有四个通道,去掉第四个通道

         img = img[:, :, :3]

    # 返回(h,w,3)的数组

return img

 

还有resize_image,oversample,以及各种set函数,这里就不一一介绍了。所谓caffe的python接口或者matlab接口,都是对caffe的输入预处理以及输出结果的处理,而不用理会网络计算的中间过程。

 

 

相关阅读:
Top