Optimizing Advent of Code 2020 day 17

2021-02-08 ∙ 13 minute read

... in which we optimize our Advent of Code 2020 day 17 solution, a Python implementation of multidimensional Game of Life, to end up with a 65x improvement.

We will focus on profiling and optimizing the existing code, in a way that helps you translate those skills to your regular, non-puzzle coding.

We'll start from the script as we left it in the initial article, and use the tests we already wrote to check that we don't break anything.


Why is it slow? #

Our solution is pretty naive: for each cell, count how many of the cell's neighbors are active, and change the cell state based on that; see this for a detailed explanation of the rules.

As we add more dimensions, the run time increases by orders of magnitude; for a world of size 16, I get:

  • 2D: .02 seconds
  • 3D: 1 second
  • 4D: 1 minute

The same happens when we increase the world size: in 4D, a world of size 20 doesn't take just 1.25 times (20/16) longer than a size 16 world, but 2.4 times, because the number of cells grows with the fourth power of the size, and (20/16)^4 ≈ 2.44!

To get a better picture of why this is happening, let's count how many cells and neighbors we need to look at every cycle (as a reminder, neighbors are all the cells in a size 3 "square" centered on the cell, except the cell itself):

>>> size = 16
>>> dims = 2
>>> size ** dims, size ** dims * (3 ** dims - 1)
(256, 2048)
>>> dims = 3
>>> size ** dims, size ** dims * (3 ** dims - 1)
(4096, 106496)
>>> dims = 4
>>> size ** dims, size ** dims * (3 ** dims - 1)
(65536, 5242880)
>>> size = 20
>>> size ** dims, size ** dims * (3 ** dims - 1)
(160000, 12800000)

That is indeed exponential growth (the number of dimensions being the exponent).

As I mentioned before, there are many optimizations to simulating Life. They usually involve one or more of:

  • reducing the number of cells to look at
  • making neighbors faster to count
  • detecting parts of the board that repeat either in space or time, and reusing the previous results
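To make the first idea concrete, here's a minimal sketch (not the approach this article follows) that stores only the coordinates of active cells in a set, so cells with no active neighbors are never visited. The step() and parse() helpers and the set-of-tuples representation are assumptions of this sketch; the rules are the puzzle's: an active cell stays active with 2 or 3 active neighbors, and an inactive cell becomes active with exactly 3.

```python
from collections import Counter
from itertools import product


def step(active, dimensions):
    """Run one cycle on a sparse world: `active` is a set of coordinate tuples."""
    # Offsets to all 3**dimensions - 1 neighbors; all zeros would be the cell itself.
    offsets = [o for o in product((-1, 0, 1), repeat=dimensions) if any(o)]
    # Each active cell adds 1 to each of its neighbors; cells that never show up
    # here have no active neighbors, so they are never looked at.
    counts = Counter(
        tuple(c + o for c, o in zip(cell, offset))
        for cell in active
        for offset in offsets
    )
    return {
        cell
        for cell, n in counts.items()
        if n == 3 or (n == 2 and cell in active)
    }


def parse(text, dimensions):
    """Turn a 2D slice made of '.' and '#' into active cells, padding with zeros."""
    padding = (0,) * (dimensions - 2)
    return {
        (x, y) + padding
        for y, line in enumerate(text.splitlines())
        for x, char in enumerate(line)
        if char == "#"
    }


example = ".#.\n..#\n###"
active = step(parse(example, 4), 4)
print(len(active))  # 29, same as the 4D "test 8 4 1" result below
```

Since only active cells are stored, there is no world size to pick, and no edge to fall off.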

We won't change our naive algorithm; instead, we'll try to make our existing Python code faster, since it is both easier to do (at least initially), and easier to translate to other Python problems.

(If we were after speed at any cost, we'd probably both use better algorithms, and switch to a faster language.)

Intro to profiling #

The Python standard library provides a profiler that collects statistics on how often, and for how long, various functions are executed.

For us, the easiest way to use it is to pass a whole script, like this:

python3 -m cProfile [scriptfile] [arg] ...

You can also profile specific bits of code; see this for details.
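As a minimal sketch of profiling just part of a program, the snippet below wraps a call between cProfile.Profile's enable() and disable(), then prints the same kind of table with pstats; slow_sum() is a made-up function, only there to have something to measure.

```python
import cProfile
import pstats


def slow_sum(n):
    # Made-up function, only here to give the profiler something to measure.
    total = 0
    for i in range(n):
        total += sum(range(i % 100))
    return total


profiler = cProfile.Profile()
profiler.enable()   # start collecting...
result = slow_sum(10_000)
profiler.disable()  # ...and stop; code outside this window isn't measured

# Print the usual table, sorted by cumulative time, limited to the top 10 rows.
stats = pstats.Stats(profiler)
stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(10)
```

Since Python 3.8, cProfile.Profile also works as a context manager (with cProfile.Profile() as profiler: ...), which takes care of the enable()/disable() pair.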

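Before profiling, we get a baseline run time for the "real" workload (the arguments are the input to use, the world size, the number of dimensions, and the number of cycles):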

$ python3 conway_cubes.py real 20 4 6
after cycle #0 (0.01s): ...
after cycle #1 (24.15s): ...
after cycle #2 (23.45s): ...
after cycle #3 (23.79s): ...
after cycle #4 (24.31s): ...
after cycle #5 (24.15s): ...
after cycle #6 (24.07s): ...
the result is 2276 (143.94s)

While working on the script, we'll simulate a smaller world for just one cycle, so we can iterate quickly. We get a baseline for that as well:

$ python3 conway_cubes.py test 8 4 1
after cycle #0 (0.00s): ...
after cycle #1 (0.65s): ...
the result is 29 (0.65s)

Let's run it through the profiler:

$ python3 -m cProfile -s cumulative conway_cubes.py test 8 4 1
after cycle #0 (0.00s): ...
after cycle #1 (1.01s): ...
the result is 29 (1.01s)

         2268874 function calls (2241818 primitive calls) in 1.020 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.020    1.020 {built-in method builtins.exec}
        1    0.000    0.000    1.020    1.020 conway_cubes.py:1(<module>)
        1    0.000    0.000    1.020    1.020 conway_cubes.py:162(main)
        1    0.000    0.000    1.020    1.020 conway_cubes.py:120(run)
        3    0.000    0.000    1.013    0.338 conway_cubes.py:104(simulate_forever)
        1    0.005    0.005    1.013    1.013 conway_cubes.py:74(simulate)
     4096    0.391    0.000    1.001    0.000 conway_cubes.py:28(get_active_neighbors)
   327680    0.216    0.000    0.396    0.000 {built-in method builtins.any}
  1557736    0.188    0.000    0.188    0.000 conway_cubes.py:37(<genexpr>)
   327680    0.143    0.000    0.143    0.000 conway_cubes.py:32(<listcomp>)
     4096    0.005    0.000    0.063    0.000 conway_cubes.py:23(make_directions)
     4096    0.058    0.000    0.058    0.000 conway_cubes.py:25(<listcomp>)
33938/8194    0.011    0.000    0.011    0.000 conway_cubes.py:56(ndenumerate)
        1    0.001    0.001    0.006    0.006 {built-in method builtins.sum}
     4097    0.001    0.000    0.006    0.000 conway_cubes.py:138(<genexpr>)
   1170/2    0.001    0.000    0.001    0.000 conway_cubes.py:89(make_world)
    146/2    0.000    0.000    0.001    0.000 conway_cubes.py:93(<listcomp>)
     4100    0.001    0.000    0.001    0.000 {built-in method builtins.len}
        6    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 conway_cubes.py:8(parse_input)
        1    0.000    0.000    0.000    0.000 conway_cubes.py:96(copy_centered_2d)
        1    0.000    0.000    0.000    0.000 conway_cubes.py:9(<listcomp>)
       10    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        5    0.000    0.000    0.000    0.000 {built-in method time.perf_counter}
        3    0.000    0.000    0.000    0.000 conway_cubes.py:10(<listcomp>)
        1    0.000    0.000    0.000    0.000 conway_cubes.py:6(<dictcomp>)
        1    0.000    0.000    0.000    0.000 {method 'splitlines' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

After the script finishes, the profiler prints the number of calls and run times for each function. We're interested in two columns:

  • cumtime, "the cumulative time spent in this and all subfunctions"
  • tottime, "the total time spent in the given function" (excluding sub-functions)

By default, the results are sorted by "standard name" (the filename:lineno(function) column), which isn't very useful; we use the -s option to sort by cumulative time.

Since the output is quite long, from now on I'll just show the relevant rows in the middle of the table.
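The same sorting and trimming can be done from Python with pstats, which also reads the files written by python3 -m cProfile -o FILE. Here's a sketch that sorts by tottime and, a bit like trimming the output with grep, prints only the rows matching a regular expression; count_neighbors() and run() are made-up stand-ins for the real script.

```python
import cProfile
import io
import os
import pstats
import tempfile


def count_neighbors(cells, x, y):
    # Made-up stand-in for get_active_neighbors(): cheap, but called many times.
    return sum(
        (x + dx, y + dy) in cells
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        if dx or dy
    )


def run(size=30):
    cells = {(x, y) for x in range(size) for y in range(size) if (x * y) % 3 == 0}
    return sum(count_neighbors(cells, x, y) for x in range(size) for y in range(size))


profiler = cProfile.Profile()
total = profiler.runcall(run)

with tempfile.TemporaryDirectory() as tmp:
    # dump_stats() writes the same format as `python3 -m cProfile -o FILE`.
    path = os.path.join(tmp, "run.prof")
    profiler.dump_stats(path)

    out = io.StringIO()
    stats = pstats.Stats(path, stream=out)
    # Sort by time spent in each function itself (tottime), and only print
    # rows whose "filename:lineno(function)" matches the regex.
    stats.sort_stats(pstats.SortKey.TIME).print_stats("count_neighbors")

report = out.getvalue()
print(report)
```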

You may notice the run time increased; that's because profiling adds some overhead. We are using the cProfile module, a C implementation of the profiler; if you try the pure-Python version, profile, it takes even longer: around 25x, on my machine.

If you're following along, you might find it useful to re-run the command automatically every time you save the file.

I used entr to do it:

echo conway_cubes.py | entr -rcs "
python3 -m cProfile -s cumulative conway_cubes.py test 8 4 1 \
| grep -A20 ncalls
"

Worst offenders #

Looking at the data, we see that almost the whole 1 second run time is spent in get_active_neighbors() and its subfunctions, which is consistent with our initial calculation.

Let's see it in full:

def get_active_neighbors(world, active, coords):
    active_neighbors = 0
    for offsets in make_directions(len(coords)):

        neighbor_coords = [
            coord + offset
            for coord, offset in zip(coords, offsets)
        ]

        if any(coord < 0 for coord in neighbor_coords):
            if active:
                raise RuntimeError(f"active on edge: {coords}")
            continue

        try:
            neighbor = world
            for coord in neighbor_coords:
                neighbor = neighbor[coord]
        except IndexError:
            if active:
                raise RuntimeError(f"active on edge: {coords}")
            continue

        active_neighbors += neighbor

    return active_neighbors

Of the total time, about .4s are spent in the function itself (tottime), and the rest in subfunctions:

  • .4s in the any(coord < 0 ...) check, .2s of which in the generator expression
  • .15s in the neighbor_coords list comprehension
  • .06s in make_directions(), almost all of it in the list comprehension

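(Comprehensions and generator expressions are treated as functions as well. Since Python 3.12, list, dict, and set comprehensions are inlined and no longer show up as separate rows; generator expressions still do.)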

make_directions() #

Let's start small.

make_directions() only takes 6% of the total time, but should be quite easy to speed up – it is a pure function (the result only depends on the arguments), and has a single argument with only a handful of values (2, 3, 4).

We could pre-compute the results for each dimension, save them in a global dict, and reuse them from there.

It turns out the functools.lru_cache decorator from the standard library lets us do just that, transparently:

from functools import lru_cache

@lru_cache()
def make_directions(dimensions):
    # ...

Now, make_directions() will save the return value for a specific argument on the first call, and subsequent calls with the same argument will return the already computed value.

Here's the result:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.390    0.000    0.940    0.000 conway_cubes.py:30(get_active_neighbors)
      ...
        1    0.000    0.000    0.000    0.000 conway_cubes.py:24(make_directions)
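To see the cache at work in isolation, here's a sketch with a hypothetical make_directions() (the article's body is elided above; this one builds the offsets with itertools.product), called once per cell of an 8x8x8x8 world, as in the test 8 4 1 run. cache_info() shows that only the first call did any work.

```python
from functools import lru_cache
from itertools import product


@lru_cache()
def make_directions(dimensions):
    # Hypothetical implementation: every offset in {-1, 0, 1} ** dimensions,
    # except all zeros (that would be the cell itself).
    # Returning a tuple means callers can't mutate the shared cached value.
    return tuple(
        offsets
        for offsets in product((-1, 0, 1), repeat=dimensions)
        if any(offsets)
    )


for _ in range(8 ** 4):  # one call per cell, like get_active_neighbors() does
    make_directions(4)

print(len(make_directions(4)))  # 80
print(make_directions.cache_info())
# CacheInfo(hits=4096, misses=1, maxsize=128, currsize=1)
```

Because every caller gets the same cached object, returning an immutable tuple is the safer choice for this sketch; with only three possible arguments, the default maxsize of 128 never evicts anything.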

Not bad for so little work.

any(coord < 0 ...) #

Emboldened by this momentous achievement, we'll go straight to the any(coord < 0 ...) check.

There's more than one way to approach it, but before exploring any of them, let's look a bit harder at get_active_neighbors():

  • for every neighbor, we're checking if any of its coordinates is < 0;
  • but by definition, the lowest a neighbor can be is -1 from the cell;
  • so that's equivalent to checking whether any of the cell's coordinates is < 1;
  • since the cell isn't moving, we can do it just once, outside the neighbor loop.
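The equivalence in the list above is easy to check by brute force. This sketch (the function names are made up) compares the per-neighbor check with the per-cell one for every cell of a small 3D world, assuming non-negative cell coordinates, as in our nested lists.

```python
from itertools import product


def any_neighbor_negative(coords):
    # The original check: done once per neighbor, inside the loop.
    return any(
        any(c + o < 0 for c, o in zip(coords, offsets))
        for offsets in product((-1, 0, 1), repeat=len(coords))
        if any(offsets)
    )


def cell_near_low_edge(coords):
    # The hoisted check: done once per cell.
    return any(c < 1 for c in coords)


# Compare both checks for every cell of a 5x5x5 world.
mismatches = [
    coords
    for coords in product(range(5), repeat=3)
    if any_neighbor_negative(coords) != cell_near_low_edge(coords)
]
print(mismatches)  # []
```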

It might not seem like a lot, but remember the neighbor count increases exponentially: in 2D, we're doing the check 8 times; in 4D, we're doing it 80 times!

So:

def get_active_neighbors(world, active, coords):
    if any(coord < 1 for coord in coords):
        if active:
            raise RuntimeError(f"active on edge: {coords}")

    active_neighbors = 0
    # ...

Which gives us a 57% improvement!

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.265    0.000    0.406    0.000 conway_cubes.py:30(get_active_neighbors)
      ...
     4096    0.003    0.000    0.005    0.000 {built-in method builtins.any}
    17656    0.003    0.000    0.003    0.000 conway_cubes.py:31(<genexpr>)

There's a slight issue, though: in the original version, if the check failed, we'd skip that neighbor (assume it's not active); now we're not doing that anymore.

This means that for a cell on the top/left/... edge of the world (index 0), we will be getting the state for its neighbors at the far end (index -1).

As long as the neighbors on the far end are inactive, it will still work; thankfully, we are checking that as well – that's what the if active check in the except IndexError branch does.

But now we've made a bit of logic dependent on another that's quite far from it. Instead of just reasoning through it every time we change something, we rely on the edge case tests to verify it for us (nothing to do, since we've already written them :).
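The wrap-around is easy to see in one dimension. In this sketch (a 1D simplification; count_checked() and count_unchecked() are made-up names), the two ways of counting agree while the far end is inactive, and disagree once it's active, which is exactly the situation the "active on edge" errors guard against.

```python
def count_checked(world, i):
    # Original approach: skip neighbors whose index is out of range.
    return sum(world[j] for j in (i - 1, i + 1) if 0 <= j < len(world))


def count_unchecked(world, i):
    # Optimized approach: index directly. For i == 0, world[-1] silently
    # wraps to the far end; only going past the far end raises IndexError.
    total = 0
    for j in (i - 1, i + 1):
        try:
            total += world[j]
        except IndexError:
            pass
    return total


quiet_edges = [0, 1, 1, 0, 0]  # far end inactive
print([count_checked(quiet_edges, i) == count_unchecked(quiet_edges, i)
       for i in range(len(quiet_edges))])  # all True

busy_edge = [0, 1, 0, 0, 1]  # far end active
print(count_checked(busy_edge, 0), count_unchecked(busy_edge, 0))  # 1 2
```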

The script up to this point.

neighbor_coords #

Next up is the neighbor_coords list comprehension.

Using functions written in C may remove some of the comprehension overhead; let's see if it works:

        neighbor_coords = map(sum, zip(coords, offsets))

It's not much better:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.376    0.000    0.382    0.000 conway_cubes.py:30(get_active_neighbors)

What if instead of using sum, which is generic to any iterable, we used a function that's made specifically for 2 numbers?

from operator import add
from itertools import starmap

# ...

        neighbor_coords = starmap(add, zip(coords, offsets))

This fares slightly better, with a 26% improvement:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.296    0.000    0.301    0.000 conway_cubes.py:33(get_active_neighbors)
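To compare these spellings outside the profiler (which adds its own per-call overhead), a timeit sketch like the one below can help; the coordinates are arbitrary sample values, and absolute numbers depend on your machine and Python version, so none are shown here.

```python
import timeit
from itertools import starmap
from operator import add

coords = (3, 5, 7, 9)     # arbitrary sample cell
offsets = (-1, 0, 1, -1)  # arbitrary sample direction

# list() forces the lazy map/starmap iterators, so every variant does the
# full computation.
variants = {
    "list comprehension": lambda: [c + o for c, o in zip(coords, offsets)],
    "map(sum, zip(...))": lambda: list(map(sum, zip(coords, offsets))),
    "starmap(add, zip(...))": lambda: list(starmap(add, zip(coords, offsets))),
}

# Make sure all variants compute the same thing before timing them.
results = {name: make() for name, make in variants.items()}
assert all(result == [2, 5, 8, 8] for result in results.values())

number = 100_000
for name, make in variants.items():
    seconds = timeit.timeit(make, number=number)
    print(f"{name:>24}: {seconds / number * 1e9:6.0f} ns per call")
```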

We've now exhausted all the obvious things to improve; most of the time is spent in get_active_neighbors, not its subfunctions.

You may remember that the any() call spent almost as much time in the function itself (tottime) as in the generator expression; indeed, calling a function, any function, seems to have significant overhead:

$ python3 -m timeit -s 'from operator import add' 'add(1, 2)'
5000000 loops, best of 5: 49.3 nsec per loop
$ python3 -m timeit '1 + 2'
50000000 loops, best of 5: 7.98 nsec per loop
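One caveat: CPython folds the constant expression 1 + 2 into 3 at compile time, so the second timing above measures little more than the loop itself. A fairer sketch keeps the operands in variables; the gap may shrink somewhat, but the function call still has to do more work than the inline addition.

```python
import timeit

# timeit runs the setup and the statement inside one generated function, so
# add, a and b are plain local variables there, and `a + b` can't be folded
# into a constant by the compiler.
setup = "from operator import add; a, b = 1, 2"
number = 1_000_000

call = timeit.timeit("add(a, b)", setup=setup, number=number)
inline = timeit.timeit("a + b", setup=setup, number=number)

# seconds / number * 1e9 = nanoseconds per loop
print(f"add(a, b): {call / number * 1e9:.1f} nsec per loop")
print(f"a + b:     {inline / number * 1e9:.1f} nsec per loop")
```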

Let's try something different.

In the for offsets ... loop, we're calling 3 functions to add 4 pairs of numbers. What if we didn't? Having more general code did help with testing, but we may be reaching a point where it's not worth it anymore.

We can validate it with a quick experiment (this temporarily breaks worlds that aren't 4D):

        neighbor_coords = [
            coords[0] + offsets[0],
            coords[1] + offsets[1],
            coords[2] + offsets[2],
            coords[3] + offsets[3],
        ]

Which gives us a 49% improvement over the list comprehension version!

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.202    0.000    0.207    0.000 conway_cubes.py:30(get_active_neighbors)

What if we make everything non-generic, and get rid of intermediary variables?

def get_active_neighbors(world, active, coords):
    c0, c1, c2, c3 = coords

    if c0 < 1 or c1 < 1 or c2 < 1 or c3 < 1:
        if active:
            raise RuntimeError(f"active on edge: {coords}")

    active_neighbors = 0
    for o0, o1, o2, o3 in make_directions(len(coords)):
        try:
            active_neighbors += world[o0 + c0][o1 + c1][o2 + c2][o3 + c3]
        except IndexError:
            if active:
                raise RuntimeError(f"active on edge: {coords}")

    return active_neighbors

That yields a 76% improvement over the list comprehension version!

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.095    0.000    0.095    0.000 conway_cubes.py:30(get_active_neighbors)

Going multidimensional, again #

At this point, we could be OK with part of the code not being generic anymore, implement one get_active_neighbors() per dimension, and use them like this:

GET_ACTIVE_NEIGHBORS = {
    2: get_active_neighbors_2d,
    3: get_active_neighbors_3d,
    4: get_active_neighbors_4d,
}

def get_active_neighbors(world, active, coords):
    return GET_ACTIVE_NEIGHBORS[len(coords)](world, active, coords)
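For instance, a hand-written 2D version would follow the 4D function above, with two coordinates instead of four. This is a sketch: DIRECTIONS_2D is an illustrative name, only the 2D entry of the dispatch dict is filled in, and cells are stored as 0/1 integers so they can be summed directly.

```python
# The 8 neighbor offsets in 2D, written out by hand.
DIRECTIONS_2D = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def get_active_neighbors_2d(world, active, coords):
    c0, c1 = coords

    if c0 < 1 or c1 < 1:
        if active:
            raise RuntimeError(f"active on edge: {coords}")

    active_neighbors = 0
    for o0, o1 in DIRECTIONS_2D:
        try:
            active_neighbors += world[o0 + c0][o1 + c1]
        except IndexError:
            if active:
                raise RuntimeError(f"active on edge: {coords}")

    return active_neighbors


GET_ACTIVE_NEIGHBORS = {2: get_active_neighbors_2d}  # 3D and 4D omitted here


def get_active_neighbors(world, active, coords):
    return GET_ACTIVE_NEIGHBORS[len(coords)](world, active, coords)


world = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
print(get_active_neighbors(world, 0, (2, 2)))  # 5
```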

This reminds me of the pattern initially proposed for make_directions()... If only there were a way to do the same thing for a function.

Well, this is Python, we can generate code at runtime. Let's do something stupid:

from textwrap import dedent

def make_get_active_neighbors_str(dimensions):
    ids = list(range(dimensions))
    return dedent(f"""

        def get_active_neighbors(world, active, coords):
            {', '.join(f'c{i}' for i in ids)} = coords

            if {' or '.join(f'c{i} < 1' for i in ids)}:
                if active:
                    raise RuntimeError(f"active on edge: {{coords}}")

            active_neighbors = 0
            for {', '.join(f'o{i}' for i in ids)} in {make_directions(dimensions)}:
                try:
                    active_neighbors += world[{']['.join(f'o{i} + c{i}' for i in ids)}]
                except IndexError:
                    if active:
                        raise RuntimeError(f"active on edge: {{coords}}")

            return active_neighbors

    """)

@lru_cache()
def make_get_active_neighbors(dimensions):
    context = {}
    exec(make_get_active_neighbors_str(dimensions), context)
    return context['get_active_neighbors']

We build a string with the source code of a function, and then use exec to execute it in a private global context. Using an explicit context dict avoids polluting the module globals: everything defined by the source code string, including our function, ends up as an item in that dict.

You may notice that we're embedding the representation of the directions list directly in the code, instead of calling make_directions(); I'm not sure this brings a great speed-up, but it can't hurt.
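To try the generator on its own, the pieces can be combined into one self-contained script; make_directions() here is an assumed reimplementation (all offsets in {-1, 0, 1}^n except all zeros), and the world is the same kind of 0/1 nested list as before. Printing the generated source shows the embedded directions list, and is the quickest way to debug what exec actually runs.

```python
from functools import lru_cache
from itertools import product
from textwrap import dedent


@lru_cache()
def make_directions(dimensions):
    # Assumed implementation: every offset except the all-zeros one.
    return [o for o in product((-1, 0, 1), repeat=dimensions) if any(o)]


def make_get_active_neighbors_str(dimensions):
    ids = list(range(dimensions))
    return dedent(f"""

        def get_active_neighbors(world, active, coords):
            {', '.join(f'c{i}' for i in ids)} = coords

            if {' or '.join(f'c{i} < 1' for i in ids)}:
                if active:
                    raise RuntimeError(f"active on edge: {{coords}}")

            active_neighbors = 0
            for {', '.join(f'o{i}' for i in ids)} in {make_directions(dimensions)}:
                try:
                    active_neighbors += world[{']['.join(f'o{i} + c{i}' for i in ids)}]
                except IndexError:
                    if active:
                        raise RuntimeError(f"active on edge: {{coords}}")

            return active_neighbors

    """)


@lru_cache()
def make_get_active_neighbors(dimensions):
    context = {}  # becomes the generated function's globals
    exec(make_get_active_neighbors_str(dimensions), context)
    return context["get_active_neighbors"]


print(make_get_active_neighbors_str(2))

get_active_neighbors = make_get_active_neighbors(2)
world = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
print(get_active_neighbors(world, 0, (2, 2)))  # 5
```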

Never use exec/eval with untrusted code, unless you want to get hacked; details.

Security issues aside, it makes code way harder to understand, and breaks a lot of conventions about where classes and functions come from; more details.

We're doing it here for ... didactic purposes. And speed. Mostly speed.

Before we try it out, we need to pull the dimension counting heuristic from ndenumerate() into a separate function:

def guess_dimensions(world):
    dimensions = 0
    while isinstance(world, list):
        dimensions += 1
        world = world[0]
    return dimensions

def ndenumerate(world, dimensions=None):
    dimensions = dimensions or guess_dimensions(world)

    # ...

... so we can also use it in simulate():

def simulate(old, new):
    dimensions = guess_dimensions(old)
    get_active_neighbors = make_get_active_neighbors(dimensions)

    for coords, active in ndenumerate(old, dimensions):
        # ...

... which leaves us with something slightly faster than our initial experiment:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4096    0.091    0.000    0.091    0.000 <string>:3(get_active_neighbors)
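The body of ndenumerate() is elided above; for readers following along without the full script, here's one hypothetical way a recursive version could look (the n/m call counts in the first profile suggest the real one is recursive too). Passing dimensions down means guess_dimensions() walks the nested lists only once, at the top level.

```python
def guess_dimensions(world):
    dimensions = 0
    while isinstance(world, list):
        dimensions += 1
        world = world[0]
    return dimensions


def ndenumerate(world, dimensions=None):
    """Yield (coords, value) pairs for every cell of a nested-list world.

    Hypothetical sketch; the article's actual implementation is not shown.
    """
    dimensions = dimensions or guess_dimensions(world)
    for i, item in enumerate(world):
        if dimensions == 1:
            yield (i,), item
        else:
            # The recursive call gets the dimension count, so it doesn't
            # have to guess it again at every level.
            for coords, value in ndenumerate(item, dimensions - 1):
                yield (i, *coords), value


world = [[[1, 0], [0, 0]], [[0, 0], [0, 1]]]
print(guess_dimensions(world))  # 3
print([coords for coords, value in ndenumerate(world) if value])
# [(0, 0, 0), (1, 1, 1)]
```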

The final version of the script.

The real world #

So, with our test 8 4 1 profiling parameters, we got a 91% (~11x) improvement in get_active_neighbors() cumulative time.

Does this reflect in the real-world performance?

$ python3 conway_cubes.py real 20 4 6
after cycle #0 (0.01s): ...
after cycle #1 (3.39s): ...
after cycle #2 (3.43s): ...
after cycle #3 (3.34s): ...
after cycle #4 (3.34s): ...
after cycle #5 (3.34s): ...
after cycle #6 (3.35s): ...
the result is 2276 (20.21s)

Almost. That's still an 86% (~7x) improvement.

Bonus: PyPy #

PyPy is an alternative Python implementation that's often faster than CPython (the standard implementation) due to its use of just-in-time compilation.

At the time of writing, it works mostly out of the box for Python code up to version 3.7 (with the exception of some CPython extensions).

First, let's see how it performs on the unoptimized script:

$ pypy3 conway_cubes.py real 20 4 6
after cycle #0 (0.02s): ...
after cycle #1 (6.43s): ...
after cycle #2 (5.53s): ...
after cycle #3 (5.47s): ...
after cycle #4 (5.45s): ...
after cycle #5 (5.47s): ...
after cycle #6 (5.46s): ...
the result is 2276 (33.82s)

4.2x; not bad, for essentially zero work! Funnily enough, their website says that "on average, PyPy is 4.2 times faster than CPython".

So, was it all for nothing? Could we have just used PyPy from the start? Yes and no, but mostly no; here's the optimized script running on PyPy:

$ pypy3 conway_cubes.py real 20 4 6
after cycle #0 (0.02s): ...
after cycle #1 (0.52s): ...
after cycle #2 (0.33s): ...
after cycle #3 (0.33s): ...
after cycle #4 (0.33s): ...
after cycle #5 (0.34s): ...
after cycle #6 (0.33s): ...
the result is 2276 (2.21s)

That is:

  • a ~9x improvement over the same script run with CPython,
  • 15x over the unoptimized script with PyPy, and
  • 65x over the unoptimized script with CPython!

Conclusions #

Profile before optimizing. Most often, your intuition about where the code is slow is wrong.

Many small optimizations add up.

A different algorithm can be better than many small optimizations. At some point, there are no small optimizations left.

Optimizations have costs. Usually, they make code harder to understand. More changes increase the likelihood of bugs; more changes to harder-to-understand code, even more so. Tests help you know you're not breaking anything. Profiling helps minimize the amount of code you change.

"My program is too slow. How do I speed it up?" from the Python Programming FAQ has more (and better) advice on the points above.

PyPy is amazing, give it a try if you can.

Function calls are slow in Python; that only matters if you're calling them millions of times.

You can do cool stuff with Python code generation; most of the time, it's not worth it.


This is part of a series: