tf.keras.Model

介绍

Model 继承自 network.py中的Network calss Model(network.Network):

他的作用是: 将Layers组织成具有训练和推理功能的对象。

`Model` groups layers into an object with training and inference features

将Model实例化的两种方式

1

使用Input,output来实例化tf.keras.Model(inputs=inputs, outputs=outputs)

2

使用继承Model的子类,需要在 `__init__`中定义所有层,在 `call` 中实现模型的前向传播。
you should define your layers in `__init__` and you should implement the model’s forward pass in `call`.
training=False参数可选,其实任何参数都可选,如下optional_parameter

class MyModel(tf.keras.Model):
    def __init__(self, filter=(5, 5)):
        super().__init__()
        self.shape = shape
        self.conv = Conv2D(1, filter, padding="same", activation="relu")

    def call(self, inputs, optional_parameter, training=False):
        print(optional_parameter)
        outputs = self.conv(inputs)
        return outputs

model = MyModel()

值得注意的问题

如果在call中调用需要训练参数的层,那么必须是类的成员变量( self. ),否则模型保存时将忽略保存这些参数。