czy36mengfei/tensorflow2_tutorials_chinese

tf.keras problem

xuchengggg opened this issue · 0 comments

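I took a model I originally built under TF 1.12 and compiled it with TF 2.0. After updating a few functions whose APIs had changed, I now get the error `unhashable type: 'ListWrapper'`, raised on the line `self.keras_model.add_loss(loss)`.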
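For context, the message itself is Python's generic `TypeError` for an object that cannot be hashed. `list` sets `__hash__` to `None`, and subclasses inherit that, so a list-like object cannot be used as a dict key or set member. TensorFlow 2.x uses a class named `ListWrapper` to track lists assigned as attributes on Keras layers and models, so the error suggests that such a wrapped list reached a place where Keras needed a hashable value. The exact trigger inside Keras is not confirmed here. The sketch below only reproduces the message. It uses a hypothetical stand-in class that happens to share the name; it is not TensorFlow's `ListWrapper`.

```python
# Hypothetical stand-in: a plain list subclass named ListWrapper.
# This is NOT TensorFlow's class; it only inherits list's "unhashable" behavior.
class ListWrapper(list):
    pass


def try_as_dict_key(obj):
    """Return the TypeError message raised when obj is used as a dict key, or None."""
    try:
        _ = {obj: "value"}  # dict keys must be hashable
    except TypeError as exc:
        return str(exc)
    return None


message = try_as_dict_key(ListWrapper([1, 2, 3]))
print(message)  # contains "unhashable type: 'ListWrapper'"
print(try_as_dict_key((1, 2, 3)))  # None: a tuple of ints is hashable
```

The exact wording around the type name can vary between Python versions, but it always includes `unhashable type: '<class name>'`.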

```python
def compile(self, learning_rate, momentum):
    """Gets the model ready for training. Adds losses, regularization, and
    metrics. Then calls the Keras compile() function.
    """
    # Optimizer object
    optimizer = keras.optimizers.SGD(
        lr=learning_rate, momentum=momentum,
        clipnorm=self.config.GRADIENT_CLIP_NORM)
    # Add losses
    # First, clear previously set losses to avoid duplication
    self.keras_model._losses = []
    self.keras_model._per_input_losses = {}
    loss_names = ["loc_loss", "class_loss", "mask_loss"]
    for name in loss_names:
        layer = self.keras_model.get_layer(name)
        if layer.output in self.keras_model.losses:
            continue
        # Take the mean here because of data-parallel training
        loss = tf.reduce_mean(layer.output, keepdims=True)
        self.keras_model.add_loss(tf.abs(loss))

    # Add L2 regularization
    # Skip gamma and beta weights of batch normalization layers.
    reg_losses = [
        keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
        for w in self.keras_model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    # Sum the per-weight penalties into a single extra loss term
    self.keras_model.add_loss(tf.add_n(reg_losses))

    # Compile
    self.keras_model.compile(
        optimizer=optimizer,
        loss=[None] * len(self.keras_model.outputs))

    # Add metrics for losses
    for name in loss_names:
        if name in self.keras_model.metrics_names:
            continue
        layer = self.keras_model.get_layer(name)
        self.keras_model.metrics_names.append(name)
        loss = tf.reduce_mean(layer.output, keepdims=True)
        self.keras_model.metrics_tensors.append(loss)
```

That is the complete code for this part. Does anyone know how to fix it?
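For reference, the L2 term in `compile()` divides each weight's penalty by that weight's element count. In tf.keras, `keras.regularizers.l2(l)(w)` computes `l * sum(w**2)`, so the extra loss is the sum over all non-batch-norm weights of `WEIGHT_DECAY * sum(w**2) / size(w)`. The stdlib-only sketch below reproduces that arithmetic. The function name `size_normalized_l2` and the dict of plain lists standing in for `trainable_weights` are hypothetical, chosen only for illustration.

```python
def size_normalized_l2(weights, weight_decay):
    """Sum of weight_decay * sum(w**2) / size(w) over all non-batch-norm weights.

    `weights` is a hypothetical stand-in for model.trainable_weights:
    a dict mapping a weight name to a flat list of its values.
    """
    total = 0.0
    for name, values in weights.items():
        # Mirror the original filter: skip batch-norm gamma/beta weights.
        if "gamma" in name or "beta" in name:
            continue
        # Equivalent to keras.regularizers.l2(weight_decay)(w)
        penalty = weight_decay * sum(v * v for v in values)
        # ...then divided by the number of elements in w, as in the original
        total += penalty / len(values)
    return total


example_weights = {
    "conv1/kernel": [1.0, 2.0],  # 0.1 * (1 + 4) / 2 = 0.25
    "bn1/gamma": [5.0],          # skipped
    "bn1/beta": [7.0],           # skipped
    "fc/bias": [3.0],            # 0.1 * 9 / 1 = 0.9
}
print(size_normalized_l2(example_weights, weight_decay=0.1))  # approximately 1.15
```

Note that the filter is a substring test on the weight name, so any weight whose name merely contains `gamma` or `beta` is skipped as well, not only batch-normalization parameters.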