Understanding tensordot


Solution 1

The idea with tensordot is pretty simple: we input the arrays and the respective axes along which the sum-reductions are intended. The axes that take part in the sum-reduction are removed from the output, and all remaining axes from the input arrays are spread out as separate axes in the output, keeping the order in which the input arrays are fed.
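
That bookkeeping can be written down directly. Here is a minimal sketch (tensordot_shape is only an illustrative helper, not a NumPy function):

def tensordot_shape(a_shape, b_shape, a_axes, b_axes):
    # Drop the sum-reduced axes, then concatenate what is left of the
    # first input's shape followed by what is left of the second's.
    keep_a = [s for i, s in enumerate(a_shape) if i not in a_axes]
    keep_b = [s for i, s in enumerate(b_shape) if i not in b_axes]
    return tuple(keep_a + keep_b)

# e.g. reducing axis 0 of a (2, 6, 5) array against axis 1 of a (3, 2, 4) array
assert tensordot_shape((2, 6, 5), (3, 2, 4), (0,), (1,)) == (6, 5, 3, 4)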

Let's look at a few sample cases with one and two axes of sum-reduction, and also swap the input places to see how the order is kept in the output.

I. One axis of sum-reduction

Inputs :

In [7]: A = np.random.randint(2, size=(2, 6, 5))
   ...: B = np.random.randint(2, size=(3, 2, 4))

Case #1:

In [9]: np.tensordot(A, B, axes=((0),(1))).shape
Out[9]: (6, 5, 3, 4)

A : (2, 6, 5) -> reduction of axis=0
B : (3, 2, 4) -> reduction of axis=1

Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`
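
For a value-level check, the same contraction can be written with np.einsum (a quick sketch added for comparison; the subscripts are chosen to match the axes above):

import numpy as np

A = np.random.randint(2, size=(2, 6, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Axis 0 of A (size 2) is aligned with axis 1 of B (size 2) and summed away;
# i is the contracted index.
td = np.tensordot(A, B, axes=((0,), (1,)))
es = np.einsum('ijk,lim->jklm', A, B)
assert td.shape == es.shape == (6, 5, 3, 4)
assert np.allclose(td, es)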

Case #2 (same as case #1 but the inputs are fed swapped):

In [8]: np.tensordot(B, A, axes=((1),(0))).shape
Out[8]: (3, 4, 6, 5)

B : (3, 2, 4) -> reduction of axis=1
A : (2, 6, 5) -> reduction of axis=0

Output : `(3, 2, 4)`, `(2, 6, 5)` ===(2 gone)==> `(3,4)` + `(6,5)` => `(3,4,6,5)`.

II. Two axes of sum-reduction

Inputs :

In [11]: A = np.random.randint(2, size=(2, 3, 5))
    ...: B = np.random.randint(2, size=(3, 2, 4))
    ...: 

Case #1:

In [12]: np.tensordot(A, B, axes=((0,1),(1,0))).shape
Out[12]: (5, 4)

A : (2, 3, 5) -> reduction of axis=(0,1)
B : (3, 2, 4) -> reduction of axis=(1,0)

Output : `(2, 3, 5)`, `(3, 2, 4)` ===(2,3 gone)==> `(5)` + `(4)` => `(5,4)`
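
To see the actual arithmetic here, each output element is an element-wise multiply followed by a sum over both contracted axes. A small sketch verifying that against np.tensordot:

import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

td = np.tensordot(A, B, axes=((0, 1), (1, 0)))   # shape (5, 4)

# Spelled out with explicit loops over the remaining axes.
manual = np.empty((5, 4))
for k in range(5):
    for m in range(4):
        # B[:, :, m].T lines up B's (axis1, axis0) with A's (axis0, axis1)
        manual[k, m] = (A[:, :, k] * B[:, :, m].T).sum()
assert np.allclose(td, manual)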

Case #2:

In [14]: np.tensordot(B, A, axes=((1,0),(0,1))).shape
Out[14]: (4, 5)

B : (3, 2, 4) -> reduction of axis=(1,0)
A : (2, 3, 5) -> reduction of axis=(0,1)

Output : `(3, 2, 4)`, `(2, 3, 5)` ===(2,3 gone)==> `(4)` + `(5)` => `(4,5)`

We can extend this scheme to any number of axes of sum-reduction.
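
For instance, a sketch with three contracted axes (the shapes here are made up just for this illustration):

import numpy as np

A = np.random.randint(2, size=(2, 3, 5, 7))
B = np.random.randint(2, size=(3, 2, 4, 5))

# Contract A's axes (0, 1, 2) against B's axes (1, 0, 3); the sizes 2, 3, 5 line up.
out = np.tensordot(A, B, axes=((0, 1, 2), (1, 0, 3)))
assert out.shape == (7, 4)   # remaining axis of A, then remaining axis of B
assert np.allclose(out, np.einsum('ijkl,jimk->lm', A, B))   # sanity check via einsum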

Solution 2

tensordot swaps axes and reshapes the inputs so it can apply np.dot to two 2d arrays. It then swaps and reshapes back to the target shape. It may be easier to experiment than to explain. There's no special tensor math going on, just extending dot to work in higher dimensions; "tensor" here just means arrays with more than 2 dimensions. If you are already comfortable with einsum, then it will be simplest to compare the results against that.
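
Roughly, that transpose-reshape-dot pipeline can be reproduced by hand (a sketch of the idea, not the actual np.tensordot source), using the question's A (2, 3, 5) and B (3, 2, 4):

import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Contract A's axes (0, 1) with B's axes (1, 0): move the contracted axes of A
# to the end and those of B to the front, flatten them into one dimension each,
# and hand the result to np.dot.
A2 = A.transpose(2, 0, 1).reshape(5, 2 * 3)   # (5, 6): free axis first
B2 = B.transpose(1, 0, 2).reshape(2 * 3, 4)   # (6, 4): contracted axes first
manual = A2.dot(B2)                           # (5, 4)
assert np.allclose(manual, np.tensordot(A, B, axes=((0, 1), (1, 0))))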

A sample test, summing on one pair of axes (A and B are the arrays from the question):

In [823]: np.tensordot(A,B,[0,1]).shape
Out[823]: (3, 5, 3, 4)
In [824]: np.einsum('ijk,lim',A,B).shape
Out[824]: (3, 5, 3, 4)
In [825]: np.allclose(np.einsum('ijk,lim',A,B),np.tensordot(A,B,[0,1]))
Out[825]: True

Another, summing on two pairs of axes:

In [826]: np.tensordot(A,B,[(0,1),(1,0)]).shape
Out[826]: (5, 4)
In [827]: np.einsum('ijk,jim',A,B).shape
Out[827]: (5, 4)
In [828]: np.allclose(np.einsum('ijk,jim',A,B),np.tensordot(A,B,[(0,1),(1,0)]))
Out[828]: True

We could do the same with the (1,0) pair (see the sketch below). Given the mix of dimensions, I don't think there's another combination.
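
For reference, that (1,0) pairing looks like this (a quick sketch, again checked against einsum):

import numpy as np

A = np.random.randint(2, size=(2, 3, 5))
B = np.random.randint(2, size=(3, 2, 4))

# Contract axis 1 of A (size 3) with axis 0 of B (size 3).
td = np.tensordot(A, B, axes=(1, 0))
assert td.shape == (2, 5, 2, 4)
assert np.allclose(td, np.einsum('ijk,jlm->iklm', A, B))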

Solution 3

The answers above are great and helped me a lot in understanding tensordot. But they don't show the actual math behind the operations. That's why I did the equivalent operations in TF 2 for myself and decided to share them here:

import tensorflow as tf

a = tf.constant([1,2.])
b = tf.constant([2,3.])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('i,j', a, b)\t\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, ((),()))}\t tf.einsum('i,j', a, b)\t\t- ((() axis of a), (() axis of b))")
print(f"{tf.tensordot(b, a, 0)}\t tf.einsum('i,j->ji', a, b)\t- ((the last 0 axes of b), (the first 0 axes of a))")
print(f"{tf.tensordot(a, b, 1)}\t\t tf.einsum('i,i', a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((0,), (0,)))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t\t tf.einsum('i,i', a, b)\t\t- ((0th axis of a), (0th axis of b))")

[[2. 3.]
 [4. 6.]]    tf.einsum('i,j', a, b)     - ((the last 0 axes of a), (the first 0 axes of b))
[[2. 3.]
 [4. 6.]]    tf.einsum('i,j', a, b)     - ((() axis of a), (() axis of b))
[[2. 4.]
 [3. 6.]]    tf.einsum('i,j->ji', a, b) - ((the last 0 axes of b), (the first 0 axes of a))
8.0          tf.einsum('i,i', a, b)     - ((the last 1 axes of a), (the first 1 axes of b))
8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))
8.0          tf.einsum('i,i', a, b)     - ((0th axis of a), (0th axis of b))

And for arrays of shape (2, 2):

a = tf.constant([[1,2],
                 [-2,3.]])

b = tf.constant([[-2,3],
                 [0,4.]])
print(f"{tf.tensordot(a, b, 0)}\t tf.einsum('ij,kl', a, b)\t- ((the last 0 axes of a), (the first 0 axes of b))")
print(f"{tf.tensordot(a, b, (0,0))}\t tf.einsum('ij,ik', a, b)\t- ((0th axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (0,1))}\t tf.einsum('ij,ki', a, b)\t- ((0th axis of a), (1st axis of b))")
print(f"{tf.tensordot(a, b, 1)}\t tf.matmul(a, b)\t\t- ((the last 1 axes of a), (the first 1 axes of b))")
print(f"{tf.tensordot(a, b, ((1,), (0,)))}\t tf.einsum('ij,jk', a, b)\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, (1, 0))}\t tf.matmul(a, b)\t\t- ((1st axis of a), (0th axis of b))")
print(f"{tf.tensordot(a, b, 2)}\t tf.reduce_sum(tf.multiply(a, b))\t- ((the last 2 axes of a), (the first 2 axes of b))")
print(f"{tf.tensordot(a, b, ((0,1), (0,1)))}\t tf.einsum('ij,ij->', a, b)\t\t- ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))")
[[[[-2.  3.]
   [ 0.  4.]]
  [[-4.  6.]
   [ 0.  8.]]]

 [[[ 4. -6.]
   [-0. -8.]]
  [[-6.  9.]
   [ 0. 12.]]]]  tf.einsum('ij,kl', a, b)   - ((the last 0 axes of a), (the first 0 axes of b))
[[-2. -5.]
 [-4. 18.]]      tf.einsum('ij,ik', a, b)   - ((0th axis of a), (0th axis of b))
[[-8. -8.]
 [ 5. 12.]]      tf.einsum('ij,ki', a, b)   - ((0th axis of a), (1st axis of b))
[[-2. 11.]
 [ 4.  6.]]      tf.matmul(a, b)            - ((the last 1 axes of a), (the first 1 axes of b))
[[-2. 11.]
 [ 4.  6.]]      tf.einsum('ij,jk', a, b)   - ((1st axis of a), (0th axis of b))
[[-2. 11.]
 [ 4.  6.]]      tf.matmul(a, b)            - ((1st axis of a), (0th axis of b))
16.0    tf.reduce_sum(tf.multiply(a, b))    - ((the last 2 axes of a), (the first 2 axes of b))
16.0    tf.einsum('ij,ij->', a, b)          - ((0th axis of a, 1st axis of a), (0th axis of b, 1st axis of b))
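
For anyone following along in NumPy, a couple of the integer-axes cases above translate directly (a minimal sketch; np.tensordot accepts the same axes argument):

import numpy as np

a = np.array([[1., 2.], [-2., 3.]])
b = np.array([[-2., 3.], [0., 4.]])

# axes=1: matrix multiplication (last axis of a against first axis of b)
assert np.allclose(np.tensordot(a, b, axes=1), a @ b)

# axes=2: contract both axes, i.e. the sum of the element-wise product (16.0 here)
assert np.isclose(np.tensordot(a, b, axes=2), (a * b).sum())

# axes=0: outer product, shape (2, 2, 2, 2)
assert np.tensordot(a, b, axes=0).shape == (2, 2, 2, 2)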

Comments

  • floflo29
    floflo29 almost 2 years

    After I learned how to use einsum, I am now trying to understand how np.tensordot works.

    However, I am a little bit lost especially regarding the various possibilities for the parameter axes.

    To understand it, as I have never practiced tensor calculus, I use the following example:

    A = np.random.randint(2, size=(2, 3, 5))
    B = np.random.randint(2, size=(3, 2, 4))
    

    In this case, what are the different possible np.tensordot computations, and how would you compute them manually?

  • floflo29
    floflo29 over 7 years
    What exactly do you mean by sum-reduction?
  • Divakar
    Divakar over 7 years
    @floflo29 Well you might know that matrix-multiplication involves elementwise multiplication keeping an axis aligned and then summation of elements along that common aligned axis. With that summation, we are losing that common axis, which is termed reduction; in short, sum-reduction.
  • floflo29
    floflo29 over 7 years
    I've made "experimentations" in comparing np.einsum and np.tensordot and I think I now do understand how it works.
  • Bryan Head
    Bryan Head almost 7 years
    Is there any way to reorder the output axes? Just use transpose at the end?
  • Divakar
    Divakar almost 7 years
    @BryanHead The only way to reorder the output axes using np.tensordot is to swap the inputs. If it doesn't get you your desired one, transpose would be the way to go.
  • kmario23
    kmario23 over 6 years
    Maybe for completeness and for newbies like me, would it be a good idea to mention that the dimensions of the axes to be sum-reduced must match? Although it's obvious from your examples, stating it explicitly is really good :) This shape match constraint is well explained in the TF docstring of tf.tensordot. Also, TF uses the term contraction or tensor contraction, so it might be a good idea to add that as well.
  • Brenlla
    Brenlla about 5 years
    I still don't fully grasp it :(. In the 1st example from the docs they are multiplying element-wise 2 arrays with shape (4,3) and then doing sum over those 2 axes. How could you get that same result using a dot product?
  • Brenlla
    Brenlla about 5 years
    The way I could reproduce the 1st result from the docs is by using np.dot on flattened 2-D arrays: for aa in a.T: for bb in b.T: print(aa.ravel().dot(bb.T.ravel()))
  • hpaulj
    hpaulj about 5 years
    The einsum equivalent of tensordot with axes=([1,0],[0,1]), is np.einsum('ijk,jil->kl',a,b). This dot also does it: a.T.reshape(5,12).dot(b.reshape(12,2)). The dot is between a (5,12) and (12,2). The a.T puts the 5 first, and also swaps the (3,4) to match b.
  • CKM
    CKM over 4 years
    Would have been better if @Divakar have added the example starting from 1-D tensor along with how each entry is computed. E.g. t1=K.variable([[1,2],[2,3]] ) t2=K.variable([2,3]) print(K.eval(tf.tensordot(t1,t2,axes=0))) output: [[[2. 3.] [4. 6.]] [[4. 6.] [6. 9.]]] Not sure how the output shape is 2x2x2.
  • Divakar
    Divakar over 4 years
    @chandresh Your query seems to be tensorflow specific. numpy.tensordot is specific to NumPy arrays. So, please post a question with the appropriate tag, if you have any.
  • CKM
    CKM over 4 years
    @Divakar both work the same way if you just change tf variables to np variables. see here tf.tensordot
  • dereks
    dereks over 4 years
    It would be better to show actual math behind each operation, not only shapes
  • Divakar
    Divakar over 4 years
    @dereks You are looking to learn how sum-reductions are done? That's too broad and beyond the scope. For the same, one can look up matrix-multiplication.
  • dereks
    dereks over 4 years
    I understand matmul, einsum and dot product. But like the author of this question I don't know how to use tensordot.
  • Divakar
    Divakar over 4 years
    @dereks So, when you are talking about math, I understood you are looking to understand sum-reductions. If not, what exactly are you looking for?
  • dereks
    dereks over 4 years
    I'm looking for a few complete examples of using tensordot. All steps (all operations inside) from the beginning to the end with different axes argument.
  • Divakar
    Divakar over 4 years
    @dereks Yeah I am expecting readers to know how matrix-multiplication works and what are sum-reductions. Hence, examples are listed in this post with respect to shapes and how the axes are aligned.
  • dereks
    dereks over 4 years
    How matrix-multiplication works: https://lh3.googleusercontent.com/Jmb3q-kvNTr1Lz3jgmIIxiPo_GGgwllP_FonFnSROmte0wc1KmM7d_aUJhHzPDsA7wB0es5OTHs51HSSlENTOltoa31TxzZMZlgGpf-l62gCHRaU3C_CsHUWw3orAOLCM5Lucpo What does this have to do with tensordot?
  • dereks
    dereks over 4 years
    Sum-reduction: print(f"{tf.einsum('ij->j', A)}\t tf.reduce_sum(A, 0)\t- Column sum")
  • Divakar
    Divakar over 4 years
    @dereks The sum-reduction term used in this post is an umbrella term for element-wise multiplication and then sum-reduction. In the context of dot/tensordot, I assumed it would be safe to put it that way. Apologies if that was confusing. Now, with matrix-multiplication you have one axis of sum-reduction (second axis of the first array against the first axis of the second array), whereas in tensordot there can be more than one axis of sum-reduction. The examples presented show how axes are aligned in the input arrays and how the output axes are obtained from those.
  • dereks
    dereks over 4 years
    @Divakar Thank you. Your comments were helpful. I added what I meant as the answer below.
  • gboffi
    gboffi about 4 years
    "all of the remaining axes from the input arrays are spread-out" This is exactly what I missed... now I have just to understand how to reduce the dimensionality of tensordot output when in einsum I use a repeated index in both inputs and the output — I know of this answer of yours but I still haven't grasped the process, maybe you could try to answer this Q using tensordot. In any case, thank you.
  • Divakar
    Divakar about 4 years
    @gboffi Well tensordot won't work there, because we need to keep one axis aligned through the inputs and also keep that in output - i. So, einsum is the way to go there.