This is a common conceptual difficulty when learning to use NumPy effectively. Normally, data processing in Python is best expressed in terms of iterators, to keep memory usage low, to maximize opportunities for parallelism with the I/O system, and to provide for reuse and combination of parts of algorithms.
But NumPy turns all that inside out: the best approach is to express the algorithm as a sequence of whole-array operations, to minimize the amount of time spent in the slow Python interpreter and maximize the amount of time spent in fast compiled NumPy routines.
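To make the contrast concrete, here is a minimal sketch (the function names are hypothetical, not from the original answer). It computes the same quantity, the sum of squares of some values, in two ways. The first uses a generator expression, so the interpreter steps through every element. The second is a single whole-array NumPy expression.

```python
import numpy as np

def sum_of_squares_iter(values):
    # Iterator style: the Python interpreter handles every element in turn.
    return sum(v * v for v in values)

def sum_of_squares_numpy(a):
    # Whole-array style: one elementwise multiply and one reduction,
    # both running in compiled NumPy code.
    return np.sum(a * a)

a = np.arange(1000)
print(sum_of_squares_iter(a) == sum_of_squares_numpy(a))
```

Both give the same answer. The second spends almost no time in the interpreter, whatever the size of the array.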
Here's the general approach I take:

1. Keep the original version of the function (which you are confident is correct) so that you can test your improved versions against it, both for correctness and for speed.
2. Work from the inside out: start with the innermost loop and see if it can be vectorized; once you've done that, move out one level and continue.
3. Spend lots of time reading the NumPy documentation. There are a lot of functions and operations in there, and they are not always brilliantly named, so it's worth getting to know them. In particular, if you find yourself thinking "if only there were a function that did such-and-such," it's well worth spending ten minutes looking for it. It's usually in there somewhere.
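The advice to keep the original function and test improved versions against it can be made mechanical with a small harness. The following is a sketch under stated assumptions: the helper `check_against` and the example functions `total_loop` and `total_vec` are hypothetical, and `timeit` is just one reasonable choice of timer.

```python
import timeit
import numpy as np

def check_against(original, improved, *args, repeat=3, number=1):
    """Hypothetical helper: check that improved(*args) agrees with
    original(*args), then return the best time for each version."""
    expected = original(*args)
    actual = improved(*args)
    # np.array_equal handles scalars and arrays alike.
    if not np.array_equal(expected, actual):
        raise AssertionError(f"mismatch: {expected!r} != {actual!r}")
    t_orig = min(timeit.repeat(lambda: original(*args), repeat=repeat, number=number))
    t_impr = min(timeit.repeat(lambda: improved(*args), repeat=repeat, number=number))
    return t_orig, t_impr

def total_loop(x):
    # Reference version: a plain Python loop, trusted to be correct.
    result = 0
    for v in x:
        result += v
    return result

def total_vec(x):
    # Candidate version: a single whole-array reduction.
    return np.sum(x)

t_orig, t_impr = check_against(total_loop, total_vec, np.arange(100_000))
print(f"loop: {t_orig:.4f}s  vectorized: {t_impr:.4f}s")
```

Running the correctness check first means a fast but wrong version is rejected before anyone looks at its timing.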
There's no substitute for practice, so I'm going to give you some example problems. The goal for each problem is to rewrite the function so that it is fully vectorized: that is, so that it consists of a sequence of NumPy operations on whole arrays, with no native Python loops (no for or while statements, no iterators or comprehensions).
Problem 1
import numpy as np

def sumproducts(x, y):
    """Return the sum of x[i] * y[j] for all pairs of indices i, j.

    >>> sumproducts(np.arange(3000), np.arange(3000))
    20236502250000
    """
    result = 0
    for i in range(len(x)):
        for j in range(len(y)):
            result += x[i] * y[j]
    return result
Problem 2
def countlower(x, y):
    """Return the number of pairs i, j such that x[i] < y[j].

    >>> countlower(np.arange(0, 200, 2), np.arange(40, 140))
    4500
    """
    result = 0
    for i in range(len(x)):
        for j in range(len(y)):
            if x[i] < y[j]:
                result += 1
    return result
Problem 3
def cleanup(x, missing=-1, value=0):
    """Return an array that's the same as x, except that where x ==
    missing, it has value instead.

    >>> cleanup(np.arange(-3, 3), value=10)
    ... # doctest: +NORMALIZE_WHITESPACE
    array([-3, -2, 10, 0, 1, 2])
    """
    result = []
    for i in range(len(x)):
        if x[i] == missing:
            result.append(value)
        else:
            result.append(x[i])
    return np.array(result)
Spoilers below. You'll get much the best results if you have a go yourself before looking at my solutions!
Answer 1
np.sum(x) * np.sum(y)
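One way to arrive at this answer is to work from the inside out. For a fixed i, the inner loop adds up x[i] * y[j] over all j, which is just x[i] * sum(y). The outer loop then adds up x[i] * sum(y) over all i, and the common factor sum(y) comes out, leaving sum(x) * sum(y). The sketch below shows the two stages as separate functions. The names are hypothetical, and the first stage deliberately still contains the outer Python loop.

```python
import numpy as np

def sumproducts_step1(x, y):
    # Inner loop vectorized: the sum over j of x[i] * y[j] is x[i] * np.sum(y).
    sy = np.sum(y)
    result = 0
    for i in range(len(x)):
        result += x[i] * sy
    return result

def sumproducts_step2(x, y):
    # Outer loop vectorized too: factor the common np.sum(y) out of the sum over i.
    return np.sum(x) * np.sum(y)

x = np.arange(3000)
print(sumproducts_step1(x, x), sumproducts_step2(x, x))
```

A direct broadcasting translation, np.sum(np.outer(x, y)), also has no Python loop, but it builds a temporary array of shape (len(x), len(y)). The factored form needs only two reductions.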
Answer 2
np.sum(np.searchsorted(np.sort(x), y))
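The trick is that np.searchsorted, with its default side='left', returns for each y[j] the insertion index into the sorted array. That index equals the number of elements of x that are strictly less than y[j]. Summing these counts over all j gives the number of pairs. The sketch below (function names hypothetical) sets this next to a more literal broadcasting translation for comparison.

```python
import numpy as np

def countlower_searchsorted(x, y):
    # After sorting x, searchsorted (default side='left') gives, for each
    # y[j], the count of elements of x strictly less than y[j].
    return np.sum(np.searchsorted(np.sort(x), y))

def countlower_broadcast(x, y):
    # Alternative: compare every pair at once via broadcasting.
    # Simpler to read, but builds a len(x) by len(y) boolean array.
    return np.count_nonzero(x[:, np.newaxis] < y)

x = np.arange(0, 200, 2)
y = np.arange(40, 140)
print(countlower_searchsorted(x, y), countlower_broadcast(x, y))
```

The sorted version costs roughly O((n + m) log n) time and O(n + m) memory. The broadcasting version costs O(n * m) in both, which matters for large inputs.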
Answer 3
np.where(x == missing, value, x)
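np.where(condition, a, b) takes elements from a where the condition is true and from b elsewhere. Scalar arguments such as value are broadcast to the shape of the condition. The sketch below (function names hypothetical) sets this next to an equivalent approach that assigns through a boolean mask on a copy.

```python
import numpy as np

def cleanup_where(x, missing=-1, value=0):
    # Take value where x == missing and x elsewhere; the scalar value
    # is broadcast to the shape of x.
    return np.where(x == missing, value, x)

def cleanup_mask(x, missing=-1, value=0):
    # Alternative: copy x, then overwrite the missing entries in place.
    result = x.copy()
    result[x == missing] = value
    return result

x = np.arange(-3, 3)
print(cleanup_where(x, value=10))
print(cleanup_mask(x, value=10))
```

One difference is worth knowing. np.where works out the result dtype from both inputs, so a float value with an integer x gives a float array. Assigning into a copy always keeps the dtype of x.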