前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >basenji框架揭秘

basenji框架揭秘

作者头像
bye
发布2021-03-20 14:44:49
5170
发布2021-03-20 14:44:49
举报
文章被收录于专栏:bye漫漫求学路

目录

  • basenji_train.py:
  • 读入json文件
  • 构建模型seqnn_model
  • 构建seqnn_trainer
  • 编译并训练模型

basenji_train.py:

代码运行流程:根据params_small.json文件获取模型参数与训练参数,然后使用seqnn.SeqNN类构建模型,然后使用trainer.Trainer类构建seqnn_trainer,以对模型进行训练,然后通过seqnn_trainer调用compile与fit函数执行训练。

读入json文件

使用以下代码读取params_small.json文件,将模型和训练的参数传给basenji_train.py的params变量

代码语言:javascript
复制
 with open(params_file) as params_open:
   params = json.load(params_open)
 params_model = params['model'] # model参数也是一个字典
 params_train = params['train']# train参数也是一个字典

params_model内容如下: (其中trunk字典是模型每一层结构的参数,后面将被设置为seqnn_model实例的属性)

代码语言:javascript
复制
"model": {
    "seq_length": 131072,
    "target_length": 1024,
"augment_rc": true,
"augment_shift": 3,

"activation": "gelu",
"batch_norm": true,
"bn_momentum": 0.9,

"trunk": [
    {
        "name": "conv_block",
        "filters": 64,
        "kernel_size": 15,
        "pool_size": 8
    },
    {
        "name": "conv_tower",
        "filters_init": 64,
        "filters_mult": 1.125,
        "kernel_size": 5,
        "pool_size": 4,
        "repeat": 2
    },
    {
        "name": "dilated_residual",
        "filters": 32,
        "rate_mult": 2,
        "repeat": 6,
        "dropout": 0.25
    },
    {
        "name": "conv_block",
        "filters": 64,
        "dropout": 0.05
    }
],
"head": {
    "name": "dense",
    "units": 3,
    "activation": "softplus"
}

}

构建模型seqnn_model

接下来将params_model传递给seqnn.SeqNN来构建模型,所用命令如下:

代码语言:javascript
复制
seqnn_model = seqnn.SeqNN(params_model) # line:104
# seqnn_model的属性包含params_model中的参数,
# 此外可以用seqnn_model.model.summary()查看模型的信息。

该类存放于seqnn.py文件中,其方法有:

  1. init
代码语言:javascript
复制
  def __init__(self, params):
    self.set_defaults()
    for key, value in params.items():
      self.__setattr__(key, value) # 将params里的属性设置为该类的实例
    self.build_model()
    self.ensemble = None
    self.embed = None

params参数即params_model字典,该构造函数将params_model中的键值对设置成seqnn_model实例的属性。

  1. build_blocksdef build_block(self, current, block_params) 参数1:current,即输入,由tf.keras.Input生成,并不带有实际的数据, 参数2:block_params,字典形式,即params_model[‘trunk’]字典,这里面存放的是对模型的每一层的参数定义 功能:使用blocks.py中所定义的block(即卷积,全连接等操作)来对输入进行操作,并返回一个current
  2. build_modeldef build_model(self, save_reprs=False) 该函数依次读取self.trunk中的参数,然后调用build_block函数构建对应的网络结构。

构建seqnn_trainer

代码语言:javascript
复制
seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data, options.out_dir)

编译并训练模型

代码如下所示:

代码语言:javascript
复制
seqnn_trainer.compile(seqnn_model) 
seqnn_trainer.fit(seqnn_model)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/03/02 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 目录
  • basenji_train.py:
  • 读入json文件
  • 构建模型seqnn_model
  • 构建seqnn_trainer
  • 编译并训练模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档