我们在《深度学习caffe--手写字体识别例程(四)》中,用到了convert_mnist_data.bin文件进行数据集格式的转换,命令如下
$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
$DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
它的作用是将mnist数据集转换为lmdb或leveldb格式的文件,以便用于深度学习的训练。这篇文章我们就来研究convert_mnist_data.bin这个文件是如何实现的。convert_mnist_data.bin文件的源文件在example/mnist/目录下,文件名为convert_mnist_data.cpp,由于这个文件中的代码比较长,我们下面把代码贴出来,并在每行或几行的代码下面进行解释。
#include
#include
#include
#if defined(USE_LEVELDB) && defined(USE_LMDB)
#include
#include
#include
#endif
#include
#include
#include // NOLINT(readability/streams)
#include
#include "boost/scoped_ptr.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"
这些代码是文件包含的头文件,是文件中需要使用到的头文件。
#if defined(USE_LEVELDB) && defined(USE_LMDB)
这是一个判断的宏,如果满足判断条件,则编译下方的代码,否则编译#else下面的代码。我们总览这个文件,发现#else在文件的结尾处,只包含了几行代码。这个宏的根本作用在于,判断是否定义了USE_LEVELDB和USE_LMDB,如果定义了则进行文件格式转换的操作,否则,不操作。这两个宏是在编译caffe源码的时候定义的。
using namespace caffe; // NOLINT(build/namespaces)
using boost::scoped_ptr;
using std::string;
这3行是这个文件需要用到的库。
DEFINE_string(backend, "lmdb", "The backend for storing the result");
这行代码在这个文件中没能找到DEFINE_string的定义。其实它是在gflags.h文件中定义的,这个文件在/usr/include/gflags/目录下,有兴趣可以打开文件研究一下,DEFINE_string是一个宏定义,这里我们只介绍一下它的作用。调用DEFINE_string之后,会生成基于backend生成一个变量FLAGS_backend,并且变量的取值为“lmdb”,"The backend for storing the result"是这个变量的说明。
uint32_t swap_endian(uint32_t val) {
val = ((val <> 8) & 0xFF00FF);
return (val <> 16);
}
这段代码是一个函数,它的作用是对32位的整形变量进行大小端转换,在《深度学习caffe--手写字体识别例程(三)》中,我们介绍了,在mnist数据集中,多字节的数据是按照大端模式存储的,也就是数据的高字节存在低地址,如果我们进行数据读取数据读出来之后,字节顺序是反的。比如一个32字节的数据0x12345678,它在mnist文件中存储时,相对地址0地址为12,1地址为34,2地址为56,3地址为78。当从文件中读取32位的数据时,读出来的是0x78563412,与原始数据正好是反的。所以需要用这个函数进行转换。
还是以0x12345678为例,从mnist中读出的值为0x78563412,调用这个函数时,将0x78563412赋值给val。在函数中,((val <> 8) & 0xFF00FF)将val右移8位并与0xFF00FF按位做与运算,得到的结果是0x00780034。两个结果再做按位或运算并赋值给val,则val =0x56001200 | 0x00780034=0x56781234。
函数的最后一行(val <> 16)将val右移16位得到的结果为0x00005678,两个结果做按位或运算,得到的结果为0x12345678,并将结果返回。经过这一系列的操作,实现了数据的转换。
void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_path, const string& db_backend) {
这段代码是数据转换的函数定义,这个函数是这个文件的核心函数,就是它实现了mnist的二进制文件到lmdb文件的转换。函数的形参分别为
image_filename图片文件名
label_filename标签文件名
db_path生成文件的存储路径
db_backend生成文件的尾缀,指定文件类型,即lmdb还是leveldb。
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
这两行代码的作用是实例化两个std::ifstream对象,它们是流式文件的对象,这两个对象将图片文件和标签文件输入,以二进制的方式打开,通过这两个对象就可以访问图片文件和标签文件。
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;
这两行代码的作用是检查image_file和label_file这两个对象,是否为0,如果为0,则打印信息“Unable to open file +文件名”,并退出程序。其中CHECK是个宏定义,它在logging.h中定义,这个文件在/usr/include/glog/目录下,它的作用是判断括号内的条件是否为0,如果为0则打印后边的内容,并退出程序。有兴趣可以翻看logging.h文件研究一下。
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;
这段代码定义了5个变量,它们分别用来保存魔数、条目数、标签数、图片的行数、图片的列数。这些变量的含义在《深度学习caffe--手写字体识别例程(三)》中有详细介绍,可以参考,它们都是从mnist数据集中读取出来的。
image_file.read(reinterpret_cast(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
这3行代码的作用是从之前实例化的图片文件的流式文件对象中读取4个字节的数据,并且保存到magic变量中,关于mnist数据集中的图片文件的格式定义可以参考《深度学习caffe--手写字体识别例程(三)》。在图片文件中前4个字节的数据就是魔数。
第2行将magic进行大小端变换,这是因为变量在文件中是按照大端存储的。
第3行对magic进行检测,看它与2051是否相等,如果不相等,输出信息,并退出程序。CHECK_EQ是宏定义,与CHECK类似,它也是在logging.h中定义的。
label_file.read(reinterpret_cast(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
这3行代码的作用是读取标签文件的魔数,并做大小端变换,最后检测它与2049是否相等,如果不相等则打印信息并退出。
image_file.read(reinterpret_cast(&num_items), 4);
num_items = swap_endian(num_items);
这2行代码的作用是读取图片文件的图片条目数,并做大小端变换。
label_file.read(reinterpret_cast(&num_labels), 4);
num_labels = swap_endian(num_labels);
这2行代码的作用是读取标签文件的标签条目数,并做大小端变换。
CHECK_EQ(num_items, num_labels);
这行代码的作用是判断图片条目数和标签条目数是否相等,不相等则退出。图片与标签是一一对应的,如果不相等说明原始文件有问题。
image_file.read(reinterpret_cast(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast(&cols), 4);
cols = swap_endian(cols);
这4行的作用是读取图片的行数和列数,并做大小端变换,由《深度学习caffe--手写字体识别例程(三)》中的介绍,我们知道图片的行数和列数都是28。
scoped_ptr db(db::GetDB(db_backend));
db->Open(db_path, db::NEW);
scoped_ptr txn(db->NewTransaction());
这3行的作用首先定义一个指向db_backend类型数据库的指针,然后新建数据库,并打开。最后定义一个指向数据库事务的指针txn,这个指针指向数据库指针db指向的数据库事务。这个事务下面主要被用作数据的转换存储。
char label;
这一句定义了一个char型变量label,它下面被用作保存标签值。
char* pixels = new char[rows * cols];
这行代码用来定义一个指向char型变量的指针,它指向一个大小为rows * cols的char型数组。它在下面用来保存一副图片的数据。
int count = 0;
string value;
这两行分别定义了一个int型count变量和一个string类型的value变量。
Datum datum;
datum.set_channels(1);
datum.set_height(rows);
datum.set_width(cols);
这几行首先定义了一个Datum类型的变量datum,Datum数据类型在caffe.proto文件中定义,这个文件位于caffe根目录的src/caffe/proto/路径下,有兴趣可以对照着caffe.proto文件对这个数据类型进行深入研究,Datum中包含的主要数据有:
channels:图片的通道数,代码中取值为1。
height:图片的高,在手写体识别例程中,取值为28
width:图片的宽,在手写体识别例程中,取值为28
data:图片的数据,在手写体识别例程中,data中包含28*28=784个数据
label:图片的label
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
这两行代码的作用是在终端上输出总的条目数和行数和列数。
for (int item_id = 0; item_id Put(key_str, value);
if (++count % 1000 == 0) {
txn->Commit();
}
}
这段代码为一个for循环,循环次数为图片文件的条目数,对每个条目进行遍历。经过上面的读取操作,图片文件对象image_file的指针已经指到了第一幅图片的位置,标签文件对象label_file的指针已经指到了第一个标签的位置。进入到for循环中,首先对图片文件进行读取,读取的大小为一副图片,并保存到pixels指向的存储区,然后读取标签文件,读取一个字节,即一个标签,并保存到label变量中。
接下来datum.set_data(pixels, rows*cols);将图片数据保存到datum数据结构中。datum.set_label(label);将标签保存到datum数据结构中。
然后string key_str = caffe::format_int(item_id, 8);定义了一个名字为key_str的字符串,它保存的是调用caffe::format_int()函数生成的字符串,它将item_id的值转换为8个字节的字符串的格式,比如item_id的取值为25时,转换完的字符串为“00000025”,这个字符串被用作下面数据库存储的键值。
再下面datum.SerializeToString(&value);将datum数据结构中的数据转换为字符串,并保存到value中。
最后将键值key_str和图像数据value写入数据库。
for循环的最后,每次count加1,如果count是1000的整数倍时,数据库提交一次。
if (count % 1000 != 0) {
txn->Commit();
}
这段代码在for循环外边,判断如果count不是1000的整数倍,说明for循环退出时,还有没提交的数据,则再提交一次。
LOG(INFO) << "Processed " << count <Close();
}
这几行是数据转换函数的末尾,打印转换完成的条目数,释放pixels指向的存储空间,并关闭数据库。
int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
FLAGS_alsologtostderr = 1;
gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
"the lmdb/leveldb format used by Caffe to load data.\n"
"Usage:\n"
" convert_mnist_data [FLAGS] input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
" http://yann.lecun.com/exdb/mnist/\n"
"You should gunzip them after downloading,"
"or directly use data/mnist/get_mnist.sh\n");
gflags::ParseCommandLineFlags(&argc, &argv, true);
这段代码是主函数的开始部分,主要是gflags相关的一些操作,这些对数据转换过程基本不会有影响,只是显示一些信息,所以这里不对这些代码进行深究。
const string& db_backend = FLAGS_backend;
这一行代码定义了一个FLAGS_backend的引用db_backend,FLAGS_backend是在代码的开头调用DEFINE_string宏进行定义的,它的取值为“lmdb”,&表示它后边定义的是一个引用。
if (argc != 4) {
gflags::ShowUsageWithFlagsRestrict(argv[0],
"examples/mnist/convert_mnist_data");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3], db_backend);
}
return 0;
}
这段代码是main函数的结尾,首先判断命令行参数个数是否为4,如果不是4,说明输入的命令不对,则打印信息。否则,输入的命令行参数为4则初始化日志,并调用convert_dataset()函数进行数据转换。
#else
int main(int argc, char** argv) {
LOG(FATAL) << "This example requires LevelDB and LMDB; " <<
"compile with USE_LEVELDB and USE_LMDB.";
}
#endif // USE_LEVELDB and USE_LMDB
这段是#else的宏,它是相对于#if defined(USE_LEVELDB) && defined(USE_LMDB)的,判读如果没有定义USE_LEVELDB或USE_LMDB,则打印错误信息并退出。
作者:fxfreefly