前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MXNet源码解读笔记1 ---- 如何解析参数文件

MXNet源码解读笔记1 ---- 如何解析参数文件

作者头像
BBuf
发布2020-04-29 09:20:07
9040
发布2020-04-29 09:20:07
举报
文章被收录于专栏:GiantPandaCV

前言

本文主要内容是解读MXNet加载并解析模型参数文件所涉及到的代码,希望读者读完本文能对MXNet参数文件的存储格式有比较清晰的了解,并可以自己来实现参数文件的解析。

解析MXNet参数文件C++小工程:https://github.com/Ldpe2G/DeepLearningForFun/tree/master/MXNet-Cpp/parsingNDArray

本文解读的MXNet代码基于版本:commit 7d2c9bf3b631433132452760734b684e39170814

Python前端代码入口

首先看从MXNet Python前端入口是如何读取NDArray参数文件的,这部分代码见 ${MXNET_ROOT}/python/mxnet/ndarray/utils.py 第149行:

代码语言:javascript
复制
def load(fname):
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    .....

这个 load 函数接收参数路径作为输入,然后调用了 MXNDArrayLoad 这个中间层的C接口函数来读取参数,MXNet底层是C++实现并提供了一层C函数的接口供前端语言去调用。

C接口层

接着来看下MXNDArrayLoad接口的实现,这部分代码见${MXNET_ROOT}/src/c_api/c_api.cc第1344行:

代码语言:javascript
复制
int MXNDArrayLoad(const char* fname,
                  uint32_t *out_size,
                  NDArrayHandle** out_arr,
                  uint32_t *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ......
}

核心代码就是首先打开文件流,接着调用NDArray类的静态函数mxnet::NDArray::Load读取并解析参数文件,然后将得到参数NDArray数组保存到data这个变量里面。所以我们只要关注dmlc::Stream这个类的实现还有mxnet::NDArray::Load这个函数的具体实现就可以了。

底层C++实现

NDArray::Load静态函数

函数具体实现见${MXNET_ROOT}/src/ndarray/ndarray.cc第 1924 行:

代码语言:javascript
复制
void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

从这里读取内容的过程可以大概看出NDArray参数文件存储的内容格式。

首先文件开头保存了两个uint64_t类型的数字,接着就是NDArray参数数组,接着是每个NDArray对应的名字数组。

读取的时候都是调用Strem类的Read函数,接下来就是看下Stream类的实现。

Stream类

看回上面打开参数文件的代码:

代码语言:javascript
复制
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));

dmlc::Stream::Create代码见dmlc-core子模块:${MXNET_ROOT}/3rdparty/dmlc-core/src/io.cc第132行:

代码语言:javascript
复制
Stream *Stream::Create(const char *uri,
                       const char * const flag,
                       bool try_create) {
  io::URI path(uri);
  return io::FileSystem::
      GetInstance(path)->Open(path, flag, try_create);
}

调用了FileSystem::GetINstance函数得到实例,并调用Open函数打开文件,这里返回的实例是LocalFileSystem类的实例,其Open函数见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc第147行:

代码语言:javascript
复制
SeekStream *LocalFileSystem::Open(const URI &path,
                                  const char* const mode,
                                  bool allow_null) {
  FILE *fp = NULL;
  const char *fname = path.name.c_str();
  using namespace std;
  std::string flag = mode;
  if (flag == "r") flag = "rb";
  fp = fopen(fname, flag.c_str());
  if (fp != NULL) {
    return new FileStream(fp, false);
  } else {
    return NULL;
  }
}

为了可读性我简化了代码,可以看到就是调用std::fopen函数打开文件,并把FILE指针传给FileStream类,

代码见${MXNET_ROOT}/3rdparty/dmlc-core/src/io/local_filesys.cc第27行:

代码语言:javascript
复制
class FileStream : public SeekStream {
 public:
  explicit FileStream(FILE *fp, bool use_stdio)
      : fp_(fp), use_stdio_(use_stdio) {}
  virtual ~FileStream(void) {
    this->Close();
  }
  virtual size_t Read(void *ptr, size_t size) {
    return std::fread(ptr, 1, size, fp_);
  }

  ......
  
 private:
  std::FILE *fp_;
  bool use_stdio_;
};

可以看到FileStream继承自SeekStrem,而且成员函数Read实现的功能是调用std::fread函数从fp_文件指针里面读取size大小字节的内容,std::fread的文档见https://en.cppreference.com/w/cpp/io/c/fread

看下每个参数的解释:

代码语言:javascript
复制
  buffer - void 指针指向从文件流中读取到内容的存取目标地址
  size   - 目标地址指针每个元素字节大小,这里由于是void指针,所以size大小恒为1
  count  - 读取的字节数
  stream - 文件流

MXNet这里的实现是把需要被读取的内存指针转换成void *,这样子就可以兼容各种基本类型的指针读取,只需要记住传入的读取元素个数是 sizeof(T) * count,就是原来类型元素个数乘以每个元素对应的字节数。

接着再看回上面打开参数文件的代码:

代码语言:javascript
复制
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));

Strem::Create返回的是Strem类型,而不是SeekStrem,所以继续往上找Strem类的定义,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h第30行:

代码语言:javascript
复制
class Stream {  // NOLINT(*)
 public:

  virtual size_t Read(void *ptr, size_t size) = 0;

  static Stream *Create(const char *uri,
                        const char* const flag,
                        bool allow_null = false);
  
  template<typename T>
  inline bool Read(T *out_data);

  template<typename T>
  inline bool ReadArray(T* data, size_t num_elems);
};

为了可读性只保留了读文件相关的代码,可以看到FileStream是重写了virtual size_t Read(void *ptr, size_t size) = 0虚函数,而回看NDArray静态Load函数:

代码语言:javascript
复制
void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

具体读参数文件内容的时候调用的是Stream类的Read(T *out_data)模板函数

代码语言:javascript
复制
  template<typename T>
  inline bool Read(T *out_data);

这个模板函数的实现很有意思,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/io.h第455:

代码语言:javascript
复制
template<typename T>
inline bool Stream::Read(T *out_data) {
  return serializer::Handler<T>::Read(this, out_data);
}

可以看到调用了Handler::Read函数,继续跟进去看Handler的实现,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第258行:

代码语言:javascript
复制
template<typename T>
struct Handler {
  ......
  /*!
   * \brief read data to stream
   * \param strm the stream to read the data.
   * \param data the pointer to the data obeject to read
   * \return whether the read is successful
   */
  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

一开始看到这串代码可能会有点懵,不过没关系,接下来我们就一步步拆解这段代码,首先看IfThenElse结构体的定义。代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第48行:

代码语言:javascript
复制
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;

template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return Then::Read(strm, data);
  }
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return Else::Read(strm, data);
  }
};

就是根据模板参数在编译期间做分支选择,根据模板参数决定调用实现分支,这里可以看到如果第一个模板参数template<bool cond, ...>true的话,就调用 Then::Read函数 ,否则调用Else::Read函数,然后看回Handler::Read函数:

代码语言:javascript
复制
  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

代码就很好理解了,如果dmlc::is_arithmetic<T>::value值为true则走ArithmeticHandler<T>,否则再进行第二次判断,先来看下具体实现${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第66行:

代码语言:javascript
复制
template<typename T>
struct is_arithmetic {
  static const bool value = std::is_arithmetic<T>::value;
};

接着只要查下C++的文档看下std::is_arithmetic<T>的定义就知道模板参数类型是什么的情况下值是true或者false,C++文档解释见https://en.cppreference.com/w/cpp/types/is_arithmetic

也就是如果模板类型T是整数型或者浮点型,value的值就是true否则是false

所以如果满足条件则会调用ArithmeticHandler::Read函数,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/serializer.h第82行:

代码语言:javascript
复制
/*! \brief Serializer for arithmetic data, handle endianness */
template<typename T>
struct ArithmeticHandler {
  ......
  inline static bool Read(Stream *strm, T *dptr) {
    bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  
    ......
    return ret;
  }
};

就是运行时调用子类重写的Read函数,从文件流中读取一个T类型元素。接着再看回其他分支选择:

代码语言:javascript
复制
  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<dmlc::is_arithmetic<T>::value,
               ArithmeticHandler<T>,
               IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                          NativePODHandler<T>,
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

如果不满足ArithmeticHandler的条件,则看下面一个判断dmlc::is_pod<T>,代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第20行:

代码语言:javascript
复制
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
  /*! \brief the value of the traits */
  static const bool value = std::is_pod<T>::value;
#else
  /*! \brief the value of the traits */
  static const bool value = false;
#endif
};

C++文档解释https://en.cppreference.com/w/cpp/types/is_pod

如果是plain old data type值就是true,关于PODType的解释大家可以参考:

https://zhuanlan.zhihu.com/p/29734547

https://en.cppreference.com/w/cpp/named_req/PODType

NativePODHandler::Read函数的实现也是和ArithmeticHandler::Read类似,也是运行时调用Stream子类重写的Read函数,从文件流中读取一个T类型元素

代码语言:javascript
复制
template<typename T>
struct NativePODHandler {
  ......
  inline static bool Read(Stream *strm, T *dptr) {
    return strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  // NOLINT(*)
  }
};

接着继续回看下一个条件判断dmlc::has_saveload<T>

代码语言:javascript
复制
  inline static bool Read(Stream *strm, T *data) {
    return
    IfThenElse<......
                          IfThenElse<dmlc::has_saveload<T>::value,
                                     SaveLoadClassHandler<T>,
                                     UndefinedSerializerFor<T>, T>,
                          T>,
               T>
    ::Read(strm, data);
  }
};

代码见${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第109行:

代码语言:javascript
复制
template<typename T>
struct has_saveload {
  /*! \brief the value of the traits */
  static const bool value = false;
};

默认值是false,不过看到${MXNET_ROOT}/include/mxnet/ndarray.h第1492行:

代码语言:javascript
复制
namespace dmlc {
/*!\brief traits */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
}  // namespace dmlc

${MXNET_ROOT}/3rdparty/dmlc-core/include/dmlc/type_traits.h第125行,DMLC_DECLARE_TRAITS的宏定义:

代码语言:javascript
复制
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \
  template<>                                          \
  struct Trait<Type> {                                \
    static const bool value = Value;                  \
  }

就知道对于NDArray类来说,dmlc::has_saveload<NDArray>::value == true,所以可以判断如果模板参数类型是NDArray则会进入SaveLoadClassHandler实现:

代码语言:javascript
复制
template<typename T>
struct SaveLoadClassHandler {
  ......
  inline static bool Read(Stream *strm, T *data) {
    return data->Load(strm);
  }
};

实际就是调用了T::Load函数,也就是NDArray::Load成员函数。

Handler类还提供了其他模板参数类型的支持比如vector<T>或者std::string

代码语言:javascript
复制
template<typename T>
struct Handler<std::vector<T> > {
  ......
  inline static bool Read(Stream *strm, std::vector<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
                      NativePODVectorHandler<T>,
                      ComposeVectorHandler<T>,
                      std::vector<T> >
    ::Read(strm, data);
  }
};

template<typename T>
struct Handler<std::basic_string<T> > {
  .....
  inline static bool Read(Stream *strm, std::basic_string<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1),
                      NativePODStringHandler<T>,
                      UndefinedSerializerFor<T>,
                      std::basic_string<T> >
    ::Read(strm, data);
  }
};

相信有了前面Handler类的解释,再来理解这两个模板类的实现也就容易多了,这里不再展开有兴趣的读者可以继续去深入了解。

MXNet参数文件解析逻辑

首先给出MXNet参数文件存储内容的格式示意图:

然后根据官方代码的解析逻辑,我自己实现的参数提取代码,为了可读性简化了代码,完整代码见文章开头的github链接:

代码语言:javascript
复制
struct cpu {
  static const int kDevMask = 1 << 0;
};
struct gpu {
  static const int kDevMask = 1 << 1;
};

enum DeviceType {
  kCPU = cpu::kDevMask,
  kGPU = gpu::kDevMask,
  kCPUPinned = 3,
  kCPUShared = 5,
};

static bool Read(std::FILE *fp, void *ptr, size_t size) {
  return std::fread(ptr, 1, size, fp) == size;
} 

int32_t loadNDArrayV2(std::vector<NDArray *>& ndarrays, std::string param_file) {

  std::FILE *fp = fopen(param_file.c_str(), "rb");

  uint64_t header, reserved;
  Read(fp, (void*)(&header), sizeof(uint64_t));
  Read(fp, (void*)(&reserved), sizeof(uint64_t))

  uint64_t nd_size;
  Read(fp, (void*)(&nd_size), sizeof(uint64_t));

  size_t size = static_cast<size_t>(nd_size);
  ndarrays.resize(size);

  // read nd data
  for (size_t i = 0; i < nd_size; ++i) {
    NDArray* nd = new NDArray;
    ndarrays[i] = nd;

    uint32_t magic;
    Read(fp, (void*)(&magic), sizeof(uint32_t));

    // load storage type
    int32_t stype;
    Read(fp, (void*)(&stype), sizeof(int32_t));

    // load shape
    uint32_t ndim_{0};
    Read(fp, (void*)(&ndim_), sizeof(uint32_t));

    size_t nread = sizeof(int64_t) * ndim_;
    int64_t *data_heap_ = new int64_t[ndim_];
    Read(fp, (void*)data_heap_, nread);

    int64_t size = 1;
    for (uint32_t i=0; i<ndim_;++i) {
      size *= data_heap_[i];
      nd->shape.push_back(data_heap_[i]);
    }

    delete[] data_heap_;

    // load context 
    DeviceType dev_type;
    int32_t dev_id;
    Read(fp, (void*)(&dev_type), sizeof(dev_type));

    Read(fp, (void*)(&dev_id), sizeof(int32_t));

    // load type flag
    int32_t type_flag;
    Read(fp, (void*)(&type_flag), sizeof(int32_t));

    size_t all_size = size * mshadow_sizeof(type_flag);
    nd->numOfBytes = all_size;
    nd->data = (void *)malloc(all_size);

    Read(fp, nd->data, nd->numOfBytes);
  }

  // read nd names
  std::vector<std::string> keys;
  uint64_t keysLen;
  Read(fp, (void*)(&keysLen), sizeof(uint64_t));
  keys.resize(keysLen);

  for (uint64_t k = 0; k < keysLen; ++k) {
     uint64_t stringLen;
     Read(fp, (void*)(&stringLen), sizeof(uint64_t));
    size_t size = static_cast<size_t>(stringLen);
    keys[k].resize(size);
    if (size != 0) {
      size_t nbytes = sizeof(char) * size;
      Read(fp, (void*)(&(keys[k][0])), nbytes);
    }
  }
  std::fclose(fp);
  return kSuccess;
}

结合示意图和代码,应该就能比较好的理解参数文件的存储格式和解析方法了。


本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-04-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
    • Python前端代码入口
      • C接口层
        • 底层C++实现
          • NDArray::Load静态函数
          • Stream类
        • MXNet参数文件解析逻辑
        相关产品与服务
        文件存储
        文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档