Counting combinations and permutations efficiently


Solution 1

If n is not far from r, then using the recursive definition of combinations is probably better: since xC0 == 1, the recursion bottoms out after only a few iterations once r is swapped for n - r (see below):

The relevant recursive definition here is:

nCr = (n-1)C(r-1) * n/r
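
For example, 5C2 = 4C1 * 5/2 = 4 * 5 / 2 = 10, which matches the direct count of 2-element subsets of a 5-element set.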

This can be nicely computed using tail recursion with the following list:

[(n - r, 0), (n - r + 1, 1), (n - r + 2, 2), ..., (n - 1, r - 1), (n, r)]

which is easily generated in Python (we omit the first entry, since nC0 = 1) by izip(xrange(n - r + 1, n + 1), xrange(1, r + 1)). Note that this assumes r <= n; check for that and swap the two if it does not hold. Also, to minimize the number of iterations, if r > n - r then set r = n - r, so the loop runs min(r, n - r) times.

Now we simply need to apply the recursion step using tail recursion with reduce. We start with 1, since nC0 is 1, and then for each pair in the list we multiply the running value by the pair's first element and divide by its second, as below.

from itertools import izip

reduce(lambda x, y: x * y[0] / y[1], izip(xrange(n - r + 1, n+1), xrange(1, r+1)), 1)
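
The snippet above is Python 2 (izip, xrange, and integer /). A minimal equivalent sketch for Python 3, where the step-wise floor division stays exact because each partial product is itself a binomial coefficient:

from functools import reduce

def ncr(n, r):
    # nCr via the multiplicative recurrence; assumes 0 <= r <= n
    r = min(r, n - r)  # use the symmetry nCr = nC(n - r) to minimize iterations
    return reduce(lambda acc, pair: acc * pair[0] // pair[1],
                  zip(range(n - r + 1, n + 1), range(1, r + 1)),
                  1)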

Solution 2

Two fairly simple suggestions:

  1. To avoid overflow, do everything in log space. Use the fact that log(a * b) = log(a) + log(b), and log(a / b) = log(a) - log(b). This makes it easy to work with very large factorials: log(n! / m!) = log(n!) - log(m!), etc.

  2. Use the gamma function instead of factorial, since gamma(n) == factorial(n - 1) and, similarly, loggamma(n) == log(factorial(n - 1)). A log-gamma routine such as math.lgamma or scipy.special.gammaln (note that scipy.stats.loggamma is a probability distribution object, not this function) is a much more efficient way to calculate log-factorials than direct summation; a minimal sketch follows this list.
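
A minimal sketch of the log-space approach using the standard library's math.lgamma (scipy.special.gammaln behaves the same way); log_ncr and ncr_approx are just illustrative helper names, and the rounded value is only exact while the result fits in float precision, as the comments below point out:

from math import exp, lgamma

def log_ncr(n, r):
    # log(nCr) = logGamma(n + 1) - logGamma(r + 1) - logGamma(n - r + 1)
    return lgamma(n + 1) - lgamma(r + 1) - lgamma(n - r + 1)

def ncr_approx(n, r):
    # round back to an integer; only reliable while nCr is small enough for a float
    return round(exp(log_ncr(n, r)))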

Solution 3

There's a function for this in scipy which hasn't been mentioned yet: scipy.special.comb. It seems efficient based on some quick timing results for your doctest (~0.004 seconds for comb(100000, 1000, exact=True) == comb(100000, 99000, exact=True)).
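
A quick usage sketch (passing exact=True makes comb return an arbitrary-precision Python int rather than a float approximation):

from scipy.special import comb

print(comb(100000, 1000, exact=True) == comb(100000, 99000, exact=True))  # True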

[While this specific question seems to be about algorithms, the question "is there a math ncr function in python" is marked as a duplicate of this...]

Solution 4

If you don't need a pure-python solution, gmpy2 might help (gmpy2.comb is very fast).
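
A minimal usage sketch, assuming gmpy2 is installed (gmpy2.comb returns an mpz, an arbitrary-precision integer):

import gmpy2

print(gmpy2.comb(100000, 1000) == gmpy2.comb(100000, 99000))  # True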

Solution 5

For Python up to 3.7:

import math


def prod(items, start=1):
    for item in items:
        start *= item
    return start


def perm(n, k):
    if not 0 <= k <= n:
        raise ValueError(
            'Values must be non-negative and n >= k in perm(n, k)')
    else:
        return prod(range(n - k + 1, n + 1))


def comb(n, k):
    if not 0 <= k <= n:
        raise ValueError(
            'Values must be non-negative and n >= k in comb(n, k)')
    else:
        k = k if k < n - k else n - k
        return prod(range(n - k + 1, n + 1)) // math.factorial(k)
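
For example, using the values from the question's doctests:

print(perm(100, 20))  # 1303995018204712451095685346159820800000
print(comb(100, 20))  # 535983370403809682970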

For Python 3.8+, math.comb() and math.perm() are available directly in the standard library:
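
A minimal illustration of the built-ins, reusing values from the doctests in the question:

import math

print(math.perm(100, 20))   # 1303995018204712451095685346159820800000
print(math.comb(100000, 1000) == math.comb(100000, 99000))  # True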


Interestingly enough, some manual implementation of the combination function may be faster than math.comb():

import functools
import math


def math_comb(n, k):
    # thin wrapper so the built-in can be timed alongside the manual versions
    return math.comb(n, k)


def comb_perm(n, k):
    k = k if k < n - k else n - k
    return math.perm(n, k) // math.factorial(k)


def comb(n, k):
    k = k if k < n - k else n - k
    return prod(range(n - k + 1, n + 1)) // math.factorial(k)


def comb_other(n, k):
    # deliberately keeps the *larger* of k and n - k, to show where the expensive computation happens
    k = k if k > n - k else n - k
    return prod(range(n - k + 1, n + 1)) // math.factorial(k)


def comb_reduce(n, k):
    k = k if k < n - k else n - k
    return functools.reduce(
        lambda x, y: x * y[0] // y[1],
        zip(range(n - k + 1, n + 1), range(1, k + 1)),
        1)


def comb_iter(n, k):
    k = k if k < n - k else n - k
    result = 1
    for i in range(1, k + 1):
        result = result * (n - i + 1) // i
    return result


def comb_iterdiv(n, k):
    k = k if k < n - k else n - k
    result = divider = 1
    for i in range(1, k + 1):
        result *= (n - i + 1)
        divider *= i
    return result // divider


def comb_fact(n, k):
    k = k if k < n - k else n - k
    return math.factorial(n) // math.factorial(n - k) // math.factorial(k)

[benchmark plot: computation time of the functions above for fixed n = 256 and increasing k, up to k = n // 2]

so comb_perm() (implemented with math.perm() and math.factorial()) is actually faster than math.comb() most of the time in these benchmarks, which show the computation time for fixed n = 256 and increasing k (up to k = n // 2).
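
A rough sketch of how comparable timings can be reproduced; the k values and repetition count below are illustrative, not the exact benchmark configuration:

import timeit

n = 256
for k in (8, 64, 128):
    for func in (math_comb, comb_perm, comb, comb_iter, comb_fact):
        t = timeit.timeit(lambda: func(n, k), number=1000)
        print(f'{func.__name__}(n={n}, k={k}): {t:.4f} s')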

Note that comb_reduce(), which is quite slow, is essentially the same approach as @wich's answer, while comb_iter(), also relatively slow, is essentially the same approach as @ZXX's answer.

Partial analysis here (without math_comb() and comb_perm(), since they are not supported by Colab's Python version -- 3.7 as of the last edit).


Comments

  • Christian Oudard
    Christian Oudard almost 3 years

    I have some code to count permutations and combinations, and I'm trying to make it work better for large numbers.

    I've found a better algorithm for permutations that avoids large intermediate results, but I still think I can do better for combinations.

    So far, I've put in a special case to reflect the symmetry of nCr, but I'd still like to find a better algorithm that avoids the call to factorial(r), which is an unnecessarily large intermediate result. Without this optimization, the last doctest takes too long trying to calculate factorial(99000).

    Can anyone suggest a more efficient way to count combinations?

    from math import factorial
    
    def product(iterable):
        prod = 1
        for n in iterable:
            prod *= n
        return prod
    
    def npr(n, r):
        """
        Calculate the number of ordered permutations of r items taken from a
        population of size n.
    
        >>> npr(3, 2)
        6
        >>> npr(100, 20)
        1303995018204712451095685346159820800000
        """
        assert 0 <= r <= n
        return product(range(n - r + 1, n + 1))
    
    def ncr(n, r):
        """
        Calculate the number of unordered combinations of r items taken from a
        population of size n.
    
        >>> ncr(3, 2)
        3
        >>> ncr(100, 20)
        535983370403809682970
        >>> ncr(100000, 1000) == ncr(100000, 99000)
        True
        """
        assert 0 <= r <= n
        if r > n // 2:
            r = n - r
        return npr(n, r) // factorial(r)
    
  • Christian Oudard
    Christian Oudard over 14 years
    Sorry, I wasn't clear: my code is Python 3, not Python 2. range in Python 3 is the same as xrange in Python 2.
  • wich
    wich over 14 years
    It seems that this implementation is O(n^2) while the tail recursion I laid out is O(n) as far as I can see.
  • wich
    wich over 14 years
    It seems a different recursive definition is used: here n choose k = (n-1 choose k-1) + (n-1 choose k), while I used n choose k = (n-1 choose k-1) * n/k.
  • wich
    wich over 14 years
    This is basically what Agor suggested, but it would be O(n^2). Since using multiplications and divisions is really not a problem anymore these days, using a different recursion relation one can make the algorithm O(n) as I described.
  • agorenst
    agorenst over 14 years
    Indeed, such is the case, wich. I will shortly edit this post to include a quick Python mock-up of the algorithm. Yours is significantly faster. I will leave my post here, in case Gorgapor has some exotic machine in which multiplication requires hours. >.>
  • JPvdMerwe
    JPvdMerwe over 14 years
    This might be O(N^2), but it precalculates all combination pairs of nCr, so if you are going to use nCr a lot with many different values, this will be faster, because lookups are O(1) and it is less susceptible to overflows. For a single value the O(N) algorithm is better, though.
  • JPvdMerwe
    JPvdMerwe over 14 years
    For a single nCr this is better, but when you have multiple nCr's (on the order of N) then the dynamic programming approach is better, even though it has a long setup time, since it won't overflow into a 'bignum' unless necessary.
  • Christian Oudard
    Christian Oudard over 14 years
    The main problem with the factorial is the size of the result, not the time calculating it. Also, the values of the result here are much bigger than can be accurately represented by a float value.
  • Christian Oudard
    Christian Oudard over 14 years
    Thanks for the reference, that's a very good practical solution. This is more of a learning project for me though, so I'm more interested in the algorithm than the practical result.
  • Christian Oudard
    Christian Oudard over 14 years
    Good suggestion doing things in log space. Not sure what you mean by "for precision" though. Wouldn't using log-floats cause roundoff error for large numbers?
  • dsimcha
    dsimcha over 14 years
    @Gorgapor: I guess a clearer way of stating this is: "To avoid overflow". Edited.
  • starblue
    starblue over 14 years
    Note that this will not give exact results, due to the limited precision of floating-point numbers.
  • dsimcha
    dsimcha over 14 years
    @starblue: But you know the real answer has to be an integer, so if you do something like round(exp(logFactorial(n))), it will be exact for small n. For large n it may be inexact, but anything other than (slow) arbitrary precision would just be dead wrong.
  • Christian Oudard
    Christian Oudard over 13 years
    There is not much trouble in computing this for small n. The point is to compute this accurately for large n, and I'm already using arbitrary precision, because I'm using Python longs.
  • Bill Bell
    Bill Bell over 9 years
    For those who come to this answer some years after it was written, gmpy is now known as gmpy2.
  • PTTHomps
    PTTHomps about 9 years
    how do you use the gamma and loggamma functions? Neither one returns an integer, rather returns a scipy.stats._distn_infrastructure.rv_frozen object.
  • PTTHomps
    PTTHomps about 9 years
    math.gamma and math.lgamma produce integer results. Still not clear on what the scipy.stats functions are doing.
  • ZXX
    ZXX almost 2 years
    I beg to differ. There is no way that a loop with multiplications gets faster as the number of iterations grows - in any universe :-) It would have to have complexity like O(1/ln(n)) => you are faster than theoretical quantum computer :-)))))) Did you check the results? With a big int numeric lib? Your prod(..) punctures MAX_INT in a jiffy => wraps around. Probably gives a fluctuating curve. The (n-k+1)/k remains int and doesn't lose a bit. When it comes to the puncturing point it flips to float. Maybe Python's [*] gets faster when you puncture MAX_INT ? :-) because it stops? :-)
  • norok2
    norok2 almost 2 years
    @ZXX perhaps the content of the plots was not sufficiently clear, sorry. Anyway, I am pretty sure I have never written or implied what you are saying. If you refer to comb_other() getting faster with larger inputs, that is because of k and n - k being swapped to show where the expensive computation is happening. You can quite easily check it yourself that all these functions get to the same numerical values, well past the int64 result threshold (and Python has built-in big int support, I think I can safely assume math.comb() is giving the correct result).