首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何总结定制的tensorflow模型

如何总结定制的tensorflow模型
EN

Stack Overflow用户
提问于 2022-07-22 04:25:01
回答 1查看 67关注 0票数 0

如何获得定制的tensorflow模型摘要?

代码语言:javascript
运行
复制
class Discriminator_block(tf.keras.layers.Layer):
  def __init__(self, num_strides):
    super(Discriminator_block, self).__init__(name='discriminator block')
    self.num_strides = num_strides
    self.conv1 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), strides=(num_strides, num_strides), padding='same', data_format='channels_first', activation=None)
    self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
    self.leaky = keras.layers.advanced_activations.LeakyReLU()

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.leaky(x)
    return x

我已经使用tensorflow编写了我自己的鉴别器块,我希望看到我的模型的摘要

所以我加入了

代码语言:javascript
运行
复制
Discriminator_block.summary()

但我发现了一个错误:

代码语言:javascript
运行
复制
'Discriminator_block' object has no attribute 'summary'

我在代码中犯了什么错误?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-22 07:06:55

您需要更改并考虑代码中的多个步骤,以打印自定义块的summary

重要的是将tf.keras.Model

  • Create更改为模型的一个实例,

  • 向您正在构建的模型输入一个随机张量,然后您就可以得到模型的summary

代码语言:javascript
运行
复制
import tensorflow as tf
class Discriminator_block(tf.keras.Model):
    def __init__(self, num_strides):
        super(Discriminator_block, self).__init__(name='discriminator block')
        self.num_strides = num_strides
        self.conv1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3), 
                                            strides=(num_strides, num_strides), 
                                            padding='same', 
                                            activation='relu', 
                                            input_shape=(28,28,3))
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.leaky = tf.keras.layers.LeakyReLU()
        
    def call(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leaky(x)
        return x
    
block = Discriminator_block(num_strides = 1)
_ = block(tf.random.normal(shape=[2, 28, 28, 3]))
block.summary()

代码语言:javascript
运行
复制
Model: "discriminator block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_1 (Conv2D)           multiple                  448       
                                                                 
 batch_normalization_1 (Batc  multiple                 112       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_1 (LeakyReLU)   multiple                  0         
                                                                 
=================================================================
Total params: 560
Trainable params: 504
Non-trainable params: 56
_________________________________________________________________
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73075106

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档