发新帖

tf.keras.layers.BatchNormalization均值方差不更新

[复制链接]
33 0

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
在1.4版本中 用keras的api来进行batch_norm发现均值和方差没有保存,应该怎么办呢?代码如下:
import tensorflow as tf
import numpy as np

tf.reset_default_graph()
graph = tf.get_default_graph()
tf.keras.backend.set_learning_phase(True)

input_shapes = [(3, )]
hidden_layer_sizes = [16, 16]

inputs = [
    tf.keras.layers.Input(shape=input_shape)
    for input_shape in input_shapes
]

concatenated = tf.keras.layers.Lambda(
    lambda x: tf.concat(x, axis=-1)
)(inputs)

out = concatenated
for units in hidden_layer_sizes:
    hidden = tf.keras.layers.Dense(
    units, activation=None
    )(out)
    bn1 = tf.keras.layers.BatchNormalization()
    batch_normed = bn1(hidden, training=True)
    #batch_normed = tf.layers.batch_normalization(hidden, training=True)
    out = tf.keras.layers.Activation('relu')(batch_normed)
#tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.moving_mean)
#tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.moving_variance)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
with tf.control_dependencies(update_ops):
    out = tf.keras.layers.Dense(
        units=1, activation='linear'
    )(out)


data = np.random.rand(100,3)
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        #sess.run(update_ops)
        sess.run(out, {inputs[0]: data})

        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope='batch_normalization')
        bn_gamma, bn_beta, bn_moving_mean, bn_moving_variance = [], [], [], []
        for variable in variables:
            val = sess.run(variable)
            nv = np.linalg.norm(val)
            if 'gamma' in variable.name:
                bn_gamma.append(nv)
            if 'beta' in variable.name:
                bn_beta.append(nv)
            if 'moving_mean' in variable.name:
                bn_moving_mean.append(nv)
            if 'moving_variance' in variable.name:
                bn_moving_variance.append(nv)

            diagnostics = {
                'bn_Q_gamma': np.mean(bn_gamma),
                'bn_Q_beta': np.mean(bn_beta),
                'bn_Q_moving_mean': np.mean(bn_moving_mean),
                'bn_Q_moving_variance': np.mean(bn_moving_variance),
            }

      
print(diagnostics)



输出
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
    {'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
咋回事呀!


您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表