How does python numpy.where() work?

191,643

Solution 1

How do they achieve internally that you are able to pass something like x > 5 into a method?

The short answer is that they don't.

Any sort of logical operation on a numpy array returns a boolean array. (i.e. __gt__, __lt__, etc all return boolean arrays where the given condition is true).

E.g.

x = np.arange(9).reshape(3,3)
print x > 5

yields:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

This is the same reason why something like if x > 5: raises a ValueError if x is a numpy array. It's an array of True/False values, not a single value.

Furthermore, numpy arrays can be indexed by boolean arrays. E.g. x[x>5] yields [6 7 8], in this case.

Honestly, it's fairly rare that you actually need numpy.where but it just returns the indicies where a boolean array is True. Usually you can do what you need with simple boolean indexing.

Solution 2

Old Answer it is kind of confusing. It gives you the LOCATIONS (all of them) of where your statment is true.

so:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

I use it as an alternative to list.index(), but it has many other uses as well. I have never used it with 2D arrays.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

New Answer It seems that the person was asking something more fundamental.

The question was how could YOU implement something that allows a function (such as where) to know what was requested.

First note that calling any of the comparison operators do an interesting thing.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

This is done by overloading the "__gt__" method. For instance:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

As you can see, "a > 4" was valid code.

You can get a full list and documentation of all overloaded functions here: http://docs.python.org/reference/datamodel.html

Something that is incredible is how simple it is to do this. ALL operations in python are done in such a way. Saying a > b is equivalent to a.gt(b)!

Solution 3

np.where returns a tuple of length equal to the dimension of the numpy ndarray on which it is called (in other words ndim) and each item of tuple is a numpy ndarray of indices of all those values in the initial ndarray for which the condition is True. (Please don't confuse dimension with shape)

For example:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y is a tuple of length 2 because x.ndim is 2. The 1st item in tuple contains row numbers of all elements greater than 4 and the 2nd item contains column numbers of all items greater than 4. As you can see, [1,2,2,2] corresponds to row numbers of 5,6,7,8 and [2,0,1,2] corresponds to column numbers of 5,6,7,8 Note that the ndarray is traversed along first dimension(row-wise).

Similarly,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


will return a tuple of length 3 because x has 3 dimensions.

But wait, there's more to np.where!

when two additional arguments are added to np.where; it will do a replace operation for all those pairwise row-column combinations which are obtained by the above tuple.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
Share:
191,643

Related videos on Youtube

pajton
Author by

pajton

I am a Python, Java and C/C++ software developer. And that's all folks!

Updated on March 29, 2020

Comments

  • pajton
    pajton over 3 years

    I am playing with numpy and digging through documentation and I have come across some magic. Namely I am talking about numpy.where():

    >>> x = np.arange(9.).reshape(3, 3)
    >>> np.where( x > 5 )
    (array([2, 2, 2]), array([0, 1, 2]))
    

    How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__ but I am looking for a detailed explanation.

  • eat
    eat over 12 years
    Just to point out that numpy.where do have 2 'operational modes', first one returns the indices, where condition is True and if optional parameters x and y are present (same shape as condition, or broadcastable to such shape!), it will return values from x when condition is True otherwise from y. So this makes where more versatile and enables it to be used more often. Thanks
  • ely
    ely about 11 years
    There can also be overhead in some cases using the __getitem__ syntax of [] over either numpy.where or numpy.take. Since __getitem__ has to also support slicing, there's some overhead. I've seen noticeable speed differences when working with the Python Pandas data structures and logically indexing very large columns. In those cases, if you don't need slicing, then take and where are actually better.
  • davidA
    davidA over 6 years
    This comparison operator overloading doesn't seem to work well with more complex logical expressions though - for example I can't do np.where(a > 30 and a < 50) or np.where(30 < a < 50) because it ends up trying to evaluate the logical AND of two arrays of booleans, which is pretty meaningless. Is there a way to write such a condition with np.where?
  • tibalt
    tibalt almost 6 years
    @meowsqueak np.where((a > 30) & (a < 50))
  • Andreas Yankopolus
    Andreas Yankopolus over 4 years
    Why is np.where() returning a list in your example?

Related