TensorFlow API 用法详解:tf.equal
equal(
x,
y,
name=None
)
对输入的 x
和 y
两个 Tensor
逐元素(element-wise)做 (x == y) 逻辑比较,返回 bool
类型的 Tensor
。
参数
x
只支持以下类型:half
,float32
,float64
,uint8
,int8
,int16
,int32
,int64
,complex64
,quint8
,qint8
,qint32
,string
,bool
,complex128
y
的类型必须与x
相同name
给这个操作取一个名称,可选
返回
bool
类型的Tensor
特性
- 支持
broadcasting
,详见Numpy
文档。
示例
基本用法:x
和 y
拥有相同的 shape
import tensorflow as tf a = tf.constant([1, 2], tf.int32) b = tf.constant([2, 2], tf.int32) with tf.Session() as sess: print(sess.run(tf.equal(a, b))) # 输出 [False True]
broadcasting
用法:x
和 y
不同 shape
x = tf.constant(["hehe", "haha", "hoho", "kaka"], tf.string) y = tf.constant("hoho", tf.string) with tf.Session() as sess: print(sess.run(tf.equal(x, y))) # 输出 [False False True False]
注意观察上面这个栗子,实际解决了在一个数组中查找某个元素索引(index
)的问题,这个特性配合 tf.cast
在生成 one-hot
向量时将会特别有用。
a = tf.constant([[1], [2]], tf.int32) b = tf.constant([[2, 1]], tf.int32) with tf.Session() as sess: print(sess.run(tf.equal(a, b))) # 输出 # [[False True] # [ True False]]
0 Comments
No comments yet.