使用方式
下载工程
创建和激活虚拟环境
安装Python依赖库
开发流程
定义自己的数据加载类,继承DataLoaderBase;
定义自己的网络结构类,继承ModelBase;
定义自己的模型训练类,继承TrainerBase;
定义自己的样本预测类,继承InferBase;
定义自己的配置文件,写入实验的相关参数;
执行训练模型和预测样本操作。
示例工程
识别MNIST库中手写数字,工程
训练:
预测:
网络结构
TensorBoard
工程架构
框架图
文件夹结构
主要组件
DataLoader
操作步骤:
创建自己的加载数据类,继承DataLoaderBase基类;
覆写和,返回训练和测试数据;
Model
操作步骤:
创建自己的网络结构类,继承ModelBase基类;
覆写,创建网络结构;
在构造器中,调用;
注意:支持绘制网络结构;
Trainer
操作步骤:
创建自己的训练类,继承TrainerBase基类;
参数:网络结构model、训练数据data;
覆写,fit数据,训练网络结构;
注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。
Infer
操作步骤:
创建自己的预测类,继承InferBase基类;
覆写,提供模型加载功能;
覆写,提供样本预测功能;
Config
定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。
Main
训练:
创建配置文件config;
创建数据加载类dataloader;
创建网络结构类model;
创建训练类trainer,参数是训练和测试数据、模型;
执行训练类trainer的train();
预测:
创建配置文件config;
处理预测样本test;
创建预测类infer;
执行预测类infer的predict();
原文:https://github.com/SpikeKing/DL-Project-Template
- 加入人工智能学院系统学习 -
领取专属 10元无门槛券
私享最新 技术干货