What is the alternative to numpy.newaxis in TensorFlow?


Solution 1

I think that would be tf.expand_dims:

tf.expand_dims(a, 1) # Or tf.expand_dims(a, -1)

Basically, we pass the index of the position where the new axis should be inserted, and the trailing axes/dims are pushed back.
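A minimal sketch, assuming TensorFlow 2.x with eager execution (the default), that applies tf.expand_dims to the same nine-element array used in the NumPy snippet in the question:

```python
import tensorflow as tf

# Same data as the NumPy snippet in the question: shape (9,)
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 9, 0])

# Insert a new axis at position 1, like a[:, np.newaxis] in NumPy
b = tf.expand_dims(a, 1)
print(b.shape)  # (9, 1)

# A negative axis counts from the end, so -1 gives the same result here
c = tf.expand_dims(a, -1)
print(c.shape)  # (9, 1)
```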

From the tf.expand_dims documentation, here are a few examples of expanding dimensions:

# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
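The examples above use the docs' shorthand notation. Assuming TensorFlow 2.x, a runnable version of the same checks could look like this (tf.zeros is used only as a convenient way to create tensors of the stated shapes):

```python
import tensorflow as tf

t = tf.zeros([2])         # 't' is a tensor of shape [2]
t2 = tf.zeros([2, 3, 5])  # 't2' is a tensor of shape [2, 3, 5]

# (tensor, axis, expected shape) triples taken from the examples above
cases = [
    (t, 0, [1, 2]),
    (t, 1, [2, 1]),
    (t, -1, [2, 1]),
    (t2, 0, [1, 2, 3, 5]),
    (t2, 2, [2, 3, 1, 5]),
    (t2, 3, [2, 3, 5, 1]),
]

for tensor, axis, expected in cases:
    got = tf.expand_dims(tensor, axis).shape.as_list()
    print(axis, got)
    assert got == expected
```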

Solution 2

The corresponding command is tf.newaxis (or None, as in NumPy). It does not have its own entry in TensorFlow's documentation, but it is briefly mentioned on the doc page of tf.strided_slice.

import tensorflow as tf

x = tf.ones((10, 10, 10))
y = x[:, tf.newaxis]  # or y = x[:, None]
print(y.shape)
# prints (10, 1, 10, 10)

Using tf.expand_dims is fine too, but, as the tf.strided_slice doc page says of the NumPy-style indexing interfaces,

Those interfaces are much more friendly, and highly recommended.
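To illustrate that both approaches insert the same axis, here is a small sketch (TensorFlow 2.x assumed) comparing indexing with tf.newaxis against tf.expand_dims, plus several new axes added in one indexing expression:

```python
import tensorflow as tf

x = tf.ones((10, 10, 10))

# Indexing with tf.newaxis and calling tf.expand_dims insert the same axis
y_index = x[:, tf.newaxis]
y_expand = tf.expand_dims(x, 1)
print(y_index.shape, y_expand.shape)  # (10, 1, 10, 10) (10, 1, 10, 10)

# Several new axes can be added at once; '...' covers the remaining dims
z = x[tf.newaxis, :, tf.newaxis, ...]
print(z.shape)  # (1, 10, 1, 10, 10)
```

Indexing tends to be more compact when several axes are added at once, while tf.expand_dims takes the axis as an argument, which can be convenient when the position is computed rather than written out.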

Solution 3

a = a[..., tf.newaxis].astype("float32")

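This works as well when a is a NumPy array, since tf.newaxis is simply None. A tf.Tensor has no .astype method by default, so for tensors use tf.cast instead.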
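The NumPy-array versus tensor distinction can be sketched as follows, assuming TensorFlow 2.x with default settings (the array values are arbitrary placeholders):

```python
import numpy as np
import tensorflow as tf

# NumPy array: tf.newaxis is None, so plain NumPy indexing accepts it
a_np = np.arange(6).reshape(2, 3)
a_np = a_np[..., tf.newaxis].astype("float32")
print(a_np.shape, a_np.dtype)  # (2, 3, 1) float32

# tf.Tensor: no .astype method by default, so cast with tf.cast instead
a_tf = tf.reshape(tf.range(6), (2, 3))
a_tf = tf.cast(a_tf[..., tf.newaxis], tf.float32)
print(a_tf.shape, a_tf.dtype.name)  # (2, 3, 1) float32
```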

Solution 4

If you want exactly the same thing as in NumPy (i.e., None), tf.newaxis is the exact alternative to np.newaxis.

Example:

In [71]: a1 = tf.constant([2,2], name="a1")

In [72]: a1
Out[72]: <tf.Tensor 'a1_5:0' shape=(2,) dtype=int32>

# add a new dimension
In [73]: a1_new = a1[tf.newaxis, :]

In [74]: a1_new
Out[74]: <tf.Tensor 'strided_slice_5:0' shape=(1, 2) dtype=int32>

# add one more dimension
In [75]: a1_new = a1[tf.newaxis, :, tf.newaxis]

In [76]: a1_new
Out[76]: <tf.Tensor 'strided_slice_6:0' shape=(1, 2, 1) dtype=int32>

This is exactly the same kind of operation you would do in NumPy: just place tf.newaxis at the position where you want the new dimension to be inserted.
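A quick side-by-side check, assuming TensorFlow 2.x, that the same index expressions give the same shapes in NumPy and TensorFlow:

```python
import numpy as np
import tensorflow as tf

# tf.newaxis and np.newaxis are both aliases for None
print(tf.newaxis is None, np.newaxis is None)  # True True

a_np = np.array([2, 2])
a_tf = tf.constant([2, 2])

# Same index expression, same resulting shape in both libraries
print(a_np[np.newaxis, :].shape, a_tf[tf.newaxis, :].shape)  # (1, 2) (1, 2)
print(a_np[np.newaxis, :, np.newaxis].shape,
      a_tf[tf.newaxis, :, tf.newaxis].shape)  # (1, 2, 1) (1, 2, 1)
```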

Author: Rahul

Updated on June 07, 2022

Comments

  • Rahul

    Hi, I am new to TensorFlow. I want to implement the following Python code in TensorFlow:

    import numpy as np
    a = np.array([1,2,3,4,5,6,7,9,0])
    print(a) ## [1 2 3 4 5 6 7 9 0]
    print(a.shape) ## (9,)
    b = a[:, np.newaxis] ### want to write this in tensorflow.
    print(b.shape) ## (9,1)