深度学习caffe--手写字体识别例程(五)—— convert_mnist_data.cpp文件详解

Winona ·
更新时间:2024-09-21
· 897 次阅读

        我们在《深度学习caffe--手写字体识别例程(四)》中,用到了convert_mnist_data.bin文件进行数据集格式的转换,命令如下

$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \ $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}

它的作用是将mnist数据集转换为lmdb或leveldb格式的文件,以便用于深度学习的训练。这篇文章我们就来研究convert_mnist_data.bin这个文件是如何实现的。convert_mnist_data.bin文件的源文件在example/mnist/目录下,文件名为convert_mnist_data.cpp,由于这个文件中的代码比较长,我们下面把代码贴出来,并在每行或几行的代码下面进行解释。

#include #include #include #if defined(USE_LEVELDB) && defined(USE_LMDB) #include #include #include #endif #include #include #include // NOLINT(readability/streams) #include #include "boost/scoped_ptr.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/db.hpp" #include "caffe/util/format.hpp"

        这些代码是文件包含的头文件,是文件中需要使用到的头文件。

#if defined(USE_LEVELDB) && defined(USE_LMDB)

        这是一个判断的宏,如果满足判断条件,则编译下方的代码,否则编译#else下面的代码。我们总览这个文件,发现#else在文件的结尾处,只包含了几行代码。这个宏的根本作用在于,判断是否定义了USE_LEVELDB和USE_LMDB,如果定义了则进行文件格式转换的操作,否则,不操作。这两个宏是在编译caffe源码的时候定义的。

using namespace caffe; // NOLINT(build/namespaces) using boost::scoped_ptr; using std::string;

        这3行是这个文件需要用到的库。

DEFINE_string(backend, "lmdb", "The backend for storing the result");

        这行代码在这个文件中没能找到DEFINE_string的定义。其实它是在gflags.h文件中定义的,这个文件在/usr/include/gflags/目录下,有兴趣可以打开文件研究一下,DEFINE_string是一个宏定义,这里我们只介绍一下它的作用。调用DEFINE_string之后,会生成基于backend生成一个变量FLAGS_backend,并且变量的取值为“lmdb”,"The backend for storing the result"是这个变量的说明。

uint32_t swap_endian(uint32_t val) { val = ((val <> 8) & 0xFF00FF); return (val <> 16); }

        这段代码是一个函数,它的作用是对32位的整形变量进行大小端转换,在《深度学习caffe--手写字体识别例程(三)》中,我们介绍了,在mnist数据集中,多字节的数据是按照大端模式存储的,也就是数据的高字节存在低地址,如果我们进行数据读取数据读出来之后,字节顺序是反的。比如一个32字节的数据0x12345678,它在mnist文件中存储时,相对地址0地址为12,1地址为34,2地址为56,3地址为78。当从文件中读取32位的数据时,读出来的是0x78563412,与原始数据正好是反的。所以需要用这个函数进行转换。

       还是以0x12345678为例,从mnist中读出的值为0x78563412,调用这个函数时,将0x78563412赋值给val。在函数中,((val <> 8) & 0xFF00FF)将val右移8位并与0xFF00FF按位做与运算,得到的结果是0x00780034。两个结果再做按位或运算并赋值给val,则val =0x56001200 | 0x00780034=0x56781234。

       函数的最后一行(val <> 16)将val右移16位得到的结果为0x00005678,两个结果做按位或运算,得到的结果为0x12345678,并将结果返回。经过这一系列的操作,实现了数据的转换。

void convert_dataset(const char* image_filename, const char* label_filename, const char* db_path, const string& db_backend) {

        这段代码是数据转换的函数定义,这个函数是这个文件的核心函数,就是它实现了mnist的二进制文件到lmdb文件的转换。函数的形参分别为

image_filename图片文件名

label_filename标签文件名

db_path生成文件的存储路径

db_backend生成文件的尾缀,指定文件类型,即lmdb还是leveldb。

std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);

       这两行代码的作用是实例化两个std::ifstream对象,它们是流式文件的对象,这两个对象将图片文件和标签文件输入,以二进制的方式打开,通过这两个对象就可以访问图片文件和标签文件。

CHECK(image_file) << "Unable to open file " << image_filename; CHECK(label_file) << "Unable to open file " << label_filename;

        这两行代码的作用是检查image_file和label_file这两个对象,是否为0,如果为0,则打印信息“Unable to open file +文件名”,并退出程序。其中CHECK是个宏定义,它在logging.h中定义,这个文件在/usr/include/glog/目录下,它的作用是判断括号内的条件是否为0,如果为0则打印后边的内容,并退出程序。有兴趣可以翻看logging.h文件研究一下。

uint32_t magic; uint32_t num_items; uint32_t num_labels; uint32_t rows; uint32_t cols;

这段代码定义了5个变量,它们分别用来保存魔数、条目数、标签数、图片的行数、图片的列数。这些变量的含义在《深度学习caffe--手写字体识别例程(三)》中有详细介绍,可以参考,它们都是从mnist数据集中读取出来的。

image_file.read(reinterpret_cast(&magic), 4); magic = swap_endian(magic); CHECK_EQ(magic, 2051) << "Incorrect image file magic.";

        这3行代码的作用是从之前实例化的图片文件的流式文件对象中读取4个字节的数据,并且保存到magic变量中,关于mnist数据集中的图片文件的格式定义可以参考《深度学习caffe--手写字体识别例程(三)》。在图片文件中前4个字节的数据就是魔数。

        第2行将magic进行大小端变换,这是因为变量在文件中是按照大端存储的。

        第3行对magic进行检测,看它与2051是否相等,如果不相等,输出信息,并退出程序。CHECK_EQ是宏定义,与CHECK类似,它也是在logging.h中定义的。

label_file.read(reinterpret_cast(&magic), 4); magic = swap_endian(magic); CHECK_EQ(magic, 2049) << "Incorrect label file magic.";

        这3行代码的作用是读取标签文件的魔数,并做大小端变换,最后检测它与2049是否相等,如果不相等则打印信息并退出。

image_file.read(reinterpret_cast(&num_items), 4); num_items = swap_endian(num_items);

        这2行代码的作用是读取图片文件的图片条目数,并做大小端变换。

label_file.read(reinterpret_cast(&num_labels), 4); num_labels = swap_endian(num_labels);

       这2行代码的作用是读取标签文件的标签条目数,并做大小端变换。

CHECK_EQ(num_items, num_labels);

        这行代码的作用是判断图片条目数和标签条目数是否相等,不相等则退出。图片与标签是一一对应的,如果不相等说明原始文件有问题。

image_file.read(reinterpret_cast(&rows), 4); rows = swap_endian(rows); image_file.read(reinterpret_cast(&cols), 4); cols = swap_endian(cols);

        这4行的作用是读取图片的行数和列数,并做大小端变换,由《深度学习caffe--手写字体识别例程(三)》中的介绍,我们知道图片的行数和列数都是28。

scoped_ptr db(db::GetDB(db_backend)); db->Open(db_path, db::NEW); scoped_ptr txn(db->NewTransaction());

        这3行的作用首先定义一个指向db_backend类型数据库的指针,然后新建数据库,并打开。最后定义一个指向数据库事务的指针txn,这个指针指向数据库指针db指向的数据库事务。这个事务下面主要被用作数据的转换存储。

char label;

       这一句定义了一个char型变量label,它下面被用作保存标签值。

char* pixels = new char[rows * cols];

       这行代码用来定义一个指向char型变量的指针,它指向一个大小为rows * cols的char型数组。它在下面用来保存一副图片的数据。

int count = 0; string value;

        这两行分别定义了一个int型count变量和一个string类型的value变量。

Datum datum; datum.set_channels(1); datum.set_height(rows); datum.set_width(cols);

        这几行首先定义了一个Datum类型的变量datum,Datum数据类型在caffe.proto文件中定义,这个文件位于caffe根目录的src/caffe/proto/路径下,有兴趣可以对照着caffe.proto文件对这个数据类型进行深入研究,Datum中包含的主要数据有:

channels:图片的通道数,代码中取值为1。

height:图片的高,在手写体识别例程中,取值为28

width:图片的宽,在手写体识别例程中,取值为28

data:图片的数据,在手写体识别例程中,data中包含28*28=784个数据

label:图片的label

LOG(INFO) << "A total of " << num_items << " items."; LOG(INFO) << "Rows: " << rows << " Cols: " << cols;

        这两行代码的作用是在终端上输出总的条目数和行数和列数。

for (int item_id = 0; item_id Put(key_str, value); if (++count % 1000 == 0) { txn->Commit(); } }

      这段代码为一个for循环,循环次数为图片文件的条目数,对每个条目进行遍历。经过上面的读取操作,图片文件对象image_file的指针已经指到了第一幅图片的位置,标签文件对象label_file的指针已经指到了第一个标签的位置。进入到for循环中,首先对图片文件进行读取,读取的大小为一副图片,并保存到pixels指向的存储区,然后读取标签文件,读取一个字节,即一个标签,并保存到label变量中。

      接下来datum.set_data(pixels, rows*cols);将图片数据保存到datum数据结构中。datum.set_label(label);将标签保存到datum数据结构中。

       然后string key_str = caffe::format_int(item_id, 8);定义了一个名字为key_str的字符串,它保存的是调用caffe::format_int()函数生成的字符串,它将item_id的值转换为8个字节的字符串的格式,比如item_id的取值为25时,转换完的字符串为“00000025”,这个字符串被用作下面数据库存储的键值。

       再下面datum.SerializeToString(&value);将datum数据结构中的数据转换为字符串,并保存到value中。

       最后将键值key_str和图像数据value写入数据库。

       for循环的最后,每次count加1,如果count是1000的整数倍时,数据库提交一次。

if (count % 1000 != 0) { txn->Commit(); }

        这段代码在for循环外边,判断如果count不是1000的整数倍,说明for循环退出时,还有没提交的数据,则再提交一次。

LOG(INFO) << "Processed " << count <Close(); }

        这几行是数据转换函数的末尾,打印转换完成的条目数,释放pixels指向的存储空间,并关闭数据库。

int main(int argc, char** argv) { #ifndef GFLAGS_GFLAGS_H_ namespace gflags = google; #endif FLAGS_alsologtostderr = 1; gflags::SetUsageMessage("This script converts the MNIST dataset to\n" "the lmdb/leveldb format used by Caffe to load data.\n" "Usage:\n" " convert_mnist_data [FLAGS] input_image_file input_label_file " "output_db_file\n" "The MNIST dataset could be downloaded at\n" " http://yann.lecun.com/exdb/mnist/\n" "You should gunzip them after downloading," "or directly use data/mnist/get_mnist.sh\n"); gflags::ParseCommandLineFlags(&argc, &argv, true);

        这段代码是主函数的开始部分,主要是gflags相关的一些操作,这些对数据转换过程基本不会有影响,只是显示一些信息,所以这里不对这些代码进行深究。

const string& db_backend = FLAGS_backend;

        这一行代码定义了一个FLAGS_backend的引用db_backend,FLAGS_backend是在代码的开头调用DEFINE_string宏进行定义的,它的取值为“lmdb”,&表示它后边定义的是一个引用。

if (argc != 4) { gflags::ShowUsageWithFlagsRestrict(argv[0], "examples/mnist/convert_mnist_data"); } else { google::InitGoogleLogging(argv[0]); convert_dataset(argv[1], argv[2], argv[3], db_backend); } return 0; }

        这段代码是main函数的结尾,首先判断命令行参数个数是否为4,如果不是4,说明输入的命令不对,则打印信息。否则,输入的命令行参数为4则初始化日志,并调用convert_dataset()函数进行数据转换。

#else int main(int argc, char** argv) { LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB."; } #endif // USE_LEVELDB and USE_LMDB

        这段是#else的宏,它是相对于#if defined(USE_LEVELDB) && defined(USE_LMDB)的,判读如果没有定义USE_LEVELDB或USE_LMDB,则打印错误信息并退出。


作者:fxfreefly



caffe 手写字体 字体 mnist convert

需要 登录 后方可回复, 如果你还没有账号请 注册新账号