代码如下:
# Demo (TensorFlow 1.x graph mode): compare the row-wise argmax of two
# random 3x4 matrices and report the fraction of rows where they agree.
import tensorflow as tf

# Two 3x4 float32 variables, each initialized uniformly between -1 and 1.
var_a_op = tf.get_variable(
    name='a',
    shape=[3, 4],
    dtype=tf.float32,
    initializer=tf.random_uniform_initializer(minval=-1, maxval=1),
)
var_b_op = tf.get_variable(
    name='b',
    shape=[3, 4],
    dtype=tf.float32,
    initializer=tf.random_uniform_initializer(minval=-1, maxval=1),
)

# Column index of the largest entry in each row (axis 1) -> shape [3].
max_a_op = tf.argmax(var_a_op, 1)
max_b_op = tf.argmax(var_b_op, 1)

# Per-row agreement as booleans, converted to 1.0 / 0.0 so it can be averaged.
equal_op = tf.equal(max_a_op, max_b_op)
cast_op = tf.cast(equal_op, tf.float32)

# Fraction of rows whose argmax indices match.
mean_op = tf.reduce_mean(cast_op)

# `tf.initialize_all_variables()` is deprecated; this is its replacement.
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Fetch every tensor in a single run so the graph executes once,
    # rather than once per printed value.
    results = sess.run([
        var_a_op,
        var_b_op,
        max_a_op,
        max_b_op,
        equal_op,
        cast_op,
        mean_op,
    ])
    for value in results:
        print(value)
输出如下:
[[ 0.86882663  0.83623457  0.18949676  0.1206305 ]
 [ 0.26341105 -0.96493745  0.42790508  0.79426742]
 [ 0.07995605 -0.54766798  0.81907773 -0.13814092]]
[[ 0.9508462   0.69852757 -0.42363882 -0.83991981]
 [ 0.18524432  0.28357625 -0.33817434 -0.10542083]
 [ 0.32140636  0.00276613  0.57352829 -0.22276378]]
[0 3 2]
[0 1 2]
[ True False  True]
[ 1.  0.  1.]
0.666667