What is the alternative to numpy.newaxis in TensorFlow?
Solution 1
I think that would be tf.expand_dims
tf.expand_dims(a, 1) # Or tf.expand_dims(a, -1)
Basically, you pass the index of the axis where the new axis should be inserted, and the existing axes from that position onward are pushed back by one.
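The insertion rule just described can be modeled in a few lines of plain Python. This is an illustrative sketch of the shape arithmetic only, not TensorFlow's implementation; `expand_dims_shape` is a hypothetical name. It assumes the documented valid range for `axis`, which is `-(rank + 1)` to `rank` inclusive, with negative values counting from the end of the output shape.

```python
def expand_dims_shape(shape, axis):
    """Hypothetical helper: shape after inserting a size-1 axis at `axis`.

    Models the shape rule only; it does not touch any tensor data.
    """
    shape = list(shape)
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    # A negative axis counts from the end of the *output* shape,
    # so -1 means "append a new last axis".
    if axis < 0:
        axis += rank + 1
    # Every axis from position `axis` onward is pushed back by one.
    return shape[:axis] + [1] + shape[axis:]


print(expand_dims_shape([2, 3, 5], 2))  # [2, 3, 1, 5]
print(expand_dims_shape([2], -1))       # [2, 1]
```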
From the tf.expand_dims documentation, here are a few examples of expanding dimensions:
# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]
# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
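The shape examples above are pseudocode; a runnable version might look like the following. This is a minimal sketch assuming TensorFlow 2.x with eager execution (the default); `t` and `t2` are zero-filled placeholder tensors chosen only for their shapes.

```python
import tensorflow as tf

# Placeholder tensors with the shapes used in the examples above.
t = tf.zeros([2])
t2 = tf.zeros([2, 3, 5])

print(tf.expand_dims(t, 0).shape)    # (1, 2)
print(tf.expand_dims(t, 1).shape)    # (2, 1)
print(tf.expand_dims(t, -1).shape)   # (2, 1)
print(tf.expand_dims(t2, 0).shape)   # (1, 2, 3, 5)
print(tf.expand_dims(t2, 2).shape)   # (2, 3, 1, 5)
print(tf.expand_dims(t2, 3).shape)   # (2, 3, 5, 1)
```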
Solution 2
The corresponding command is tf.newaxis (or None, as in NumPy). It does not have an entry of its own in TensorFlow's documentation, but it is briefly mentioned on the doc page of tf.strided_slice.
import tensorflow as tf
x = tf.ones((10,10,10))
y = x[:, tf.newaxis]  # or y = x[:, None]
print(y.shape)
# prints (10, 1, 10, 10)
Using tf.expand_dims is fine too but, as the tf.strided_slice documentation says when recommending the NumPy-style indexing syntax over calling the op directly, "Those interfaces are much more friendly, and highly recommended."
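To check that the indexing forms and tf.expand_dims really produce the same result, the sketch below builds the same tensor three ways. It assumes TensorFlow 2.x eager execution; `y1`, `y2` and `y3` are illustrative names.

```python
import tensorflow as tf

# tf.newaxis is simply an alias for None, just like np.newaxis.
print(tf.newaxis is None)  # True

x = tf.ones((10, 10, 10))

# Three ways to insert a size-1 axis at position 1.
y1 = x[:, tf.newaxis]
y2 = x[:, None]
y3 = tf.expand_dims(x, 1)

print(y1.shape, y2.shape, y3.shape)
```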
Solution 3
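a = a[..., tf.newaxis].astype("float32")  # a is a NumPy array here
This works as well, provided a is a NumPy array: .astype is a NumPy method, and a tf.Tensor does not have it by default (use tf.cast there).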
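Because .astype belongs to NumPy, that one-liner behaves differently depending on what a is. A minimal sketch of both cases, assuming TensorFlow 2.x with default settings (NumPy behavior on tensors not enabled); `a_np` and `a_tf` are illustrative names. For a tf.Tensor, tf.cast takes the place of .astype.

```python
import numpy as np
import tensorflow as tf

# Case 1: `a` is a NumPy array (e.g. image data loaded with NumPy).
# tf.newaxis is None, so this is ordinary NumPy indexing.
a_np = np.arange(6, dtype=np.uint8).reshape(2, 3)
a_np = a_np[..., tf.newaxis].astype("float32")
print(a_np.shape, a_np.dtype)  # (2, 3, 1) float32

# Case 2: `a` is a tf.Tensor, which has no .astype by default; use tf.cast.
a_tf = tf.reshape(tf.range(6), (2, 3))  # int32 tensor
a_tf = tf.cast(a_tf[..., tf.newaxis], tf.float32)
print(a_tf.shape, a_tf.dtype)
```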
Solution 4
If you're interested in exactly the same type (i.e. None) as in NumPy, then tf.newaxis is the exact alternative to np.newaxis.
Example:
In [71]: a1 = tf.constant([2,2], name="a1")
In [72]: a1
Out[72]: <tf.Tensor 'a1_5:0' shape=(2,) dtype=int32>
# add a new dimension
In [73]: a1_new = a1[tf.newaxis, :]
In [74]: a1_new
Out[74]: <tf.Tensor 'strided_slice_5:0' shape=(1, 2) dtype=int32>
# add one more dimension
In [75]: a1_new = a1[tf.newaxis, :, tf.newaxis]
In [76]: a1_new
Out[76]: <tf.Tensor 'strided_slice_6:0' shape=(1, 2, 1) dtype=int32>
This is exactly the same kind of operation you would do in NumPy: just put tf.newaxis at the position where you want the new dimension to appear.
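The rule "put the new axis where you want it" can be modeled in plain Python to see how each index item maps to the output shape. This is a simplified sketch of NumPy-style indexing semantics, not TensorFlow's code: it only handles `:`, None/tf.newaxis and `...`, and `indexed_shape`, `_IndexCapture` and `idx` are hypothetical names introduced here so that real indexing syntax can be used.

```python
class _IndexCapture:
    """Hypothetical helper: idx[...] returns the index object Python builds."""
    def __getitem__(self, key):
        return key


idx = _IndexCapture()


def indexed_shape(shape, index):
    """Shape after indexing with ':', None/newaxis and '...' only.

    Simplified model: integer indices and partial slices are not supported.
    """
    if not isinstance(index, tuple):
        index = (index,)
    for item in index:
        if not (item is None or item is Ellipsis or item == slice(None)):
            raise TypeError(f"unsupported index item: {item!r}")
    if index.count(Ellipsis) > 1:
        raise IndexError("an index can only have a single ellipsis")
    # Only ':' consumes an existing axis in this model.
    n_consumed = sum(1 for item in index if isinstance(item, slice))
    if n_consumed > len(shape):
        raise IndexError("too many indices")
    # Without an explicit '...', the remaining axes are kept at the end.
    if Ellipsis not in index:
        index = index + (Ellipsis,)
    # Expand '...' into as many ':' as there are unconsumed axes.
    pos = index.index(Ellipsis)
    fill = (slice(None),) * (len(shape) - n_consumed)
    index = index[:pos] + fill + index[pos + 1:]

    out, dims = [], iter(shape)
    for item in index:
        if item is None:
            out.append(1)           # new axis of length 1
        else:
            out.append(next(dims))  # ':' keeps the existing axis
    return tuple(out)


print(indexed_shape((2,), idx[None, :, None]))    # (1, 2, 1)
print(indexed_shape((10, 10, 10), idx[:, None]))  # (10, 1, 10, 10)
```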
Rahul
Updated on June 07, 2022
Comments
Rahul:
Hi, I am new to TensorFlow. I want to implement the following Python code in TensorFlow:
import numpy as np
a = np.array([1,2,3,4,5,6,7,9,0])
print(a)  ## [1 2 3 4 5 6 7 9 0]
print(a.shape)  ## (9,)
b = a[:, np.newaxis]  ### want to write this in tensorflow.
print(b.shape)  ## (9,1)
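For reference, a minimal TensorFlow translation of that NumPy snippet, assuming TensorFlow 2.x with eager execution (the default). Both the indexing form and tf.expand_dims give the (9, 1) shape.

```python
import tensorflow as tf

a = tf.constant([1, 2, 3, 4, 5, 6, 7, 9, 0])
print(a.shape)  # (9,)

# Option 1: NumPy-style indexing; tf.newaxis (or None) inserts the axis.
b = a[:, tf.newaxis]
print(b.shape)  # (9, 1)

# Option 2: the explicit op; axis=1 (or -1) appends a trailing axis here.
b2 = tf.expand_dims(a, 1)
print(b2.shape)  # (9, 1)
```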