本文共 1045 字,大约阅读时间需要 3 分钟。
Tensorflow 的 argmax 接口可以返回一阶以上张量最大值所对应的分量索引。
比如 tensorflow.argmax ( [1,2,3,10,1]) 返回 10对应的索引3 。
对于超过一阶的张量,需要指定要搜索的是第几维的元素,这个维是以0开始的。比如
对于a = [[[10.0,25.0,3.0,4.0] , ] 这样一个张量,想找最里层的元素的最大值,可以tensorflow.argmax ( a , 2 ) 来获取。
下面是分别对三阶和一阶张量找最大值的例子。
import tensorflow as tf
def findMaxFromRank3() : """ 找出一个三阶张量第三维(索引是2)的最大值 """ a =tf.Variable( [[[10.0,25.0,3.0,4.0] , [10.0,251.0,35.0,4.0]] , [[100.0,25.0,3.0,4.0] , [10.0,250.0,3500.0,4.0]] ] ) b = tf.argmax( a , 2 ) se = tf.Session() init = tf.global_variables_initializer() se.run( init ) r = se.run( b ) ar = se.run( a ) se.close() print( r ) print ( type (r ) ) print( ar[0][0][r[0][0]] , "," , ar[0][1][ r[0][1] ] ) print( ar[1][0][r[1][0]] , "," , ar[1][1][ r[1][1] ] )def findMaxFromRank1() : """ 找出一个一阶张量第一维(索引是0)的最大值 """ a =tf.Variable( [10.0,250.0,3510.0,4.0 ] ) b = tf.argmax( a , 0 ) se = tf.Session() init = tf.global_variables_initializer() se.run( init ) r = se.run( b ) ar = se.run( a ) se.close() print( r ) print ( type (r ) ) print( ar[r] ) findMaxFromRank3()findMaxFromRank1()转载地址:http://rmmqf.baihongyu.com/