Optimizing Conway

Conway’s Game of Life seems to be a common programming exercise. I had to program it in Pascal when in High School and in C in an intro college programming course. I remember in college, since I had already programmed it before, that I wanted to optimize the algorithm. However, a combination of writing in C and having only a week to work on it didn’t leave me with enough time to implement anything fancy.

A couple years later, I hiked the Appalachian Trail. Seven months away from computers, just hiking day in and day out. One of the things I found myself contemplating when walking up and down hills all day was that pesky Game of Life algorithm and ways that I could improve it.

Fast forward through twenty intervening years of life and experience with a few other programming languages to last weekend. I needed a fun programming exercise to raise my spirits so I looked up the rules to Conway’s Game of Life, sat down with vim and python, and implemented a few versions to test out some of the ideas I’d had kicking around in my head for a quarter century.

This blog post will only contain a few snippets of code to illustrate the differences between each approach.  Full code can be checked out from this github repository.

The Naive Approach: Tracking the whole board

The naive branch is an approximation of how I would have first coded Conway’s Game of Life way back in that high school programming course. The grid of cells is what I would have called a two dimensional array in my Pascal and C days. In Python, I’ve more often heard it called a list of lists. Each entry in the outer list is a row of the grid, and each row is itself represented by another list. Each entry in an inner list is a cell in that row of the grid. If a cell is populated, then the list entry contains True. If not, then the list entry contains False.

One populated cell surrounded by empty cells would look like this:

board = [
         [False, False, False],
         [False, True, False],
         [False, False, False],
        ]

Looking up an individual cell’s status is a matter of looking up an index in two lists: First the y-index in the outer list and then the x-index in an inner list:


# Is there a populated cell at x-axis 0, y-axis 1?
# The outer list is indexed by y and the inner list by x.
if board[1][0] is True:
    pass

Checking for changes is done by looping through every cell on the board and checking whether each cell’s neighbors make the cell fit a rule to populate or depopulate it on the next iteration.

for y_idx, row in enumerate(board):
    for x_idx, cell in enumerate(row):
        if cell:
            if not check_will_live((x_idx, y_idx), board, max_x, max_y):
                next_board[y_idx][x_idx] = False
        else:
            if check_new_life((x_idx, y_idx), board, max_x, max_y):
                next_board[y_idx][x_idx] = True

This is a simple mapping of the two-dimensional grid that Conway’s Game of Life takes place on into a computer data structure and then a literal translation of Conway’s ruleset onto those cells. However, it seems dreadfully inefficient. Even in college I could see that there should be easy ways to speed this up; I just needed the time to implement them.
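The loop above leans on two helpers, check_will_live() and check_new_life(), whose bodies are not shown in this post. The following is a minimal sketch of how they might look for the list-of-lists board. The call signatures are taken from the snippet; the function bodies, the find_neighbors(), count_populated(), and naive_step() helpers, and the reading of max_x and max_y as the width and height of the grid are assumptions made for illustration and may differ from the code in the repository.

```python
import copy


def find_neighbors(cell, max_x, max_y):
    """Return the coordinates of the in-bounds neighbors of cell."""
    x, y = cell
    return [
        (nx, ny)
        for ny in (y - 1, y, y + 1)
        for nx in (x - 1, x, x + 1)
        if (nx, ny) != (x, y) and 0 <= nx < max_x and 0 <= ny < max_y
    ]


def count_populated(cell, board, max_x, max_y):
    """Count how many neighbors of cell are populated."""
    return sum(1 for nx, ny in find_neighbors(cell, max_x, max_y) if board[ny][nx])


def check_will_live(cell, board, max_x, max_y):
    """A populated cell survives with two or three populated neighbors."""
    return count_populated(cell, board, max_x, max_y) in (2, 3)


def check_new_life(cell, board, max_x, max_y):
    """An empty cell is populated when exactly three neighbors are populated."""
    return count_populated(cell, board, max_x, max_y) == 3


def naive_step(board):
    """Compute the next generation by visiting every cell on the board."""
    max_y, max_x = len(board), len(board[0])
    next_board = copy.deepcopy(board)
    for y_idx, row in enumerate(board):
        for x_idx, cell in enumerate(row):
            if cell:
                if not check_will_live((x_idx, y_idx), board, max_x, max_y):
                    next_board[y_idx][x_idx] = False
            else:
                if check_new_life((x_idx, y_idx), board, max_x, max_y):
                    next_board[y_idx][x_idx] = True
    return next_board


# A horizontal "blinker" flips to a vertical one after a single generation.
blinker = [
    [False, False, False],
    [True, True, True],
    [False, False, False],
]
next_generation = naive_step(blinker)
```

Every cell triggers a neighbor count here, populated or not, which is exactly the cost the later branches try to avoid.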

Intermediate: Better checking

The intermediate branch rectifies inefficiencies in how the next generation’s cells are checked. The naive branch checks every single cell that is present in the grid. However, in most Conway setups, most of the cells are blank. If we find a way to ignore most of the blank cells, then it would save us a lot of work. We can’t ignore all blank cells, though; if a blank cell has exactly three populated neighbors then the blank cell will become populated in the next generation.

The key to satisfying both of those is to realize that all the cells we’re going to need to change will either be populated (in which case, they could die and become empty in the next generation) or they will be a neighbor of a populated cell (in which case, they may become populated next generation). So we can loop through our board and ignore all of the unpopulated cells at first. If we find a populated cell, then we must both check that cell to see if it will die and also check its empty neighbors to see if they will be filled in the next generation.

The major change to implement that is here:

checked_cells = set()
# We still loop through every cell in the board but now
# the toplevel code to do something if the cell is empty
# has been removed.
for y_idx, row in enumerate(board):
    for x_idx, cell in enumerate(row):
        if cell:
            if not check_will_live((x_idx, y_idx), board, max_x, max_y):
                next_board[y_idx][x_idx] = False
            # Instead, inside of the conditional block to
            # process when a cell is populated, there's
            # a new loop to check all of the neighbors of
            # a populated cell.
            for neighbor in (n for n in find_neighbors((x_idx, y_idx), max_x, max_y)
                             if n not in checked_cells):
                # If the cell is empty, then we check whether
                # it should be populated in the next generation
                if not board[neighbor[1]][neighbor[0]]:
                    checked_cells.add(neighbor)
                    if check_new_life(neighbor, board, max_x, max_y):
                        next_board[neighbor[1]][neighbor[0]] = True

Observant readers might also notice that I’ve added a checked_cells set. This tracks which empty cells we’ve already examined to see if they will be populated next generation. Making use of that means that we will only check a specific empty cell once per generation no matter how many populated cells it’s a neighbor of.

These optimizations to the checking code made things about 6x as fast as the naive approach.
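One rough way to see where the savings come from is to count how many cells need a neighbor check under each approach. The sketch below does that for a single glider on a 50x50 board. The cells_to_examine() helper, this find_neighbors() body, and the board size are assumptions chosen for illustration; the count is not a timing measurement and says nothing about the 6x figure itself.

```python
def find_neighbors(cell, max_x, max_y):
    """Return the set of in-bounds neighbors of cell."""
    x, y = cell
    return {
        (nx, ny)
        for ny in (y - 1, y, y + 1)
        for nx in (x - 1, x, x + 1)
        if (nx, ny) != (x, y) and 0 <= nx < max_x and 0 <= ny < max_y
    }


def cells_to_examine(board, max_x, max_y):
    """Collect the populated cells plus every neighbor of a populated cell."""
    candidates = set()
    for y_idx, row in enumerate(board):
        for x_idx, cell in enumerate(row):
            if cell:
                candidates.add((x_idx, y_idx))
                candidates.update(find_neighbors((x_idx, y_idx), max_x, max_y))
    return candidates


max_x = max_y = 50
board = [[False] * max_x for _ in range(max_y)]
for x, y in [(1, 0), (2, 1), (0, 2), (1, 2), (2, 2)]:  # a glider
    board[y][x] = True

naive_checks = max_x * max_y
intermediate_checks = len(cells_to_examine(board, max_x, max_y))
print(naive_checks, intermediate_checks)
```

For this board the naive approach runs a neighbor check on all 2500 cells, while only 16 cells are populated or adjacent to a populated cell. The loop still walks the whole grid to find the populated cells, which is the cost the next branch goes after.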

Gridless: Only tracking populated cells

The principle behind the intermediate branch of only operating on populated cells and their neighbors seemed like it should be applicable to the data structure I was storing the grid in as well as the checks. Instead of using fixed length arrays to store both the populated and empty portions of the grid, I figured it should be possible to simply store the populated portions of the grid and then use those for all of our operations.

However, C is a very anemic language when it comes to built-in data structures. If I was going to do that in my college class, I would have had to implement a linked list or a hash map data structure before I even got to the point where I could implement the rules of Conway’s Game of Life. Today, working in Python with its built-in data types, it was very quick to implement a data structure of only the populated cells.

For the gridless branch, I replaced the 2d array with a set. The set contained tuples of (x-coordinate, y-coordinate) which defined the populated cells. One populated cell surrounded by empty cells would look like this:

board = set((
             (1,1),
           ))

Using a set had all sorts of advantages:

  • Setup became a matter of adding tuples representing just the populated cells to the set instead of creating the list of lists and putting True or False into every one of the list’s cells:
        - board = [[False] * max_x for _ in range(0, max_y)]
        - for y in range(0, max_y):
        -     for x in range(0, max_x):
        -         board[y][x] = (x, y) in initial_dataset
        + board = set()
        + for x, y in initial_dataset:
        +     board.add((x, y))
    
  • For most cases, the set will be smaller than the list of lists as it only has to store entries for the populated cells, not for all of the cells in the grid. Some patterns may be larger than the list of lists (because the overhead of the set is more than the overhead for a list of lists) but those patterns are largely unstable; after a few generations, the number of depopulated cells will increase to a level where the set will be smaller.
  • Displaying the grid only has to loop through the populated cells instead of looping through every entry in the list of lists to find the populated cells:

        - for y in range(0, max_y):
        -     for x in range(0, max_x):
        -         if board[y][x]:
        -             screen.addstr(y, x, ' ', curses.A_REVERSE)
        + for (x, y) in board:
        +     screen.addstr(y, x, ' ', curses.A_REVERSE)
    
  • Instead of copying the old board and then changing the status of cells each generation, it is now feasible to simply populate a new set with populated cells every generation. This is because the set() is empty except for the populated cells whereas the list of lists would need to have both the populated and the empty cells set to either True or False.

             - next_board = copy.deepcopy(board)
             + next_board = set()
    
  • Similar to the display loop, the main loop which checks what happens to the cells can now just loop through the populated cells instead of having to search the entire grid for them:

            # Perform checks and update the board
            for cell in board:
                if check_will_live(cell, board, max_x, max_y):
                    next_board.add(cell)
                babies = check_new_life(cell, board, max_x, max_y)
                next_board.update(babies)
            board = next_board
    
  • Checking whether a cell is populated now becomes a simple containment check on a set which is faster than looking up the list indexes in the 2d array:

    - if board[cell[1]][cell[0]]:
    + if cell in board:
    

Gridless made the program about 3x faster than intermediate, or about 20x faster than naive.
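To make the set-based flow concrete, here is a minimal, self-contained sketch of one generation in the gridless style. The loop mirrors the snippet above, and check_new_life() returns a set of "babies" as that snippet implies. The helper bodies and the gridless_step() wrapper are assumptions for illustration, not the code from the repository.

```python
def find_neighbors(cell, max_x, max_y):
    """Return the set of in-bounds neighbors of cell."""
    x, y = cell
    return {
        (nx, ny)
        for ny in (y - 1, y, y + 1)
        for nx in (x - 1, x, x + 1)
        if (nx, ny) != (x, y) and 0 <= nx < max_x and 0 <= ny < max_y
    }


def count_populated(cell, board, max_x, max_y):
    """Count the populated neighbors of cell using set membership."""
    return sum(1 for n in find_neighbors(cell, max_x, max_y) if n in board)


def check_will_live(cell, board, max_x, max_y):
    """A populated cell survives with two or three populated neighbors."""
    return count_populated(cell, board, max_x, max_y) in (2, 3)


def check_new_life(cell, board, max_x, max_y):
    """Return the empty neighbors of cell that will be populated next."""
    babies = set()
    for neighbor in find_neighbors(cell, max_x, max_y):
        if neighbor in board:
            continue
        if count_populated(neighbor, board, max_x, max_y) == 3:
            babies.add(neighbor)
    return babies


def gridless_step(board, max_x, max_y):
    """Build the next generation as a brand new set of populated cells."""
    next_board = set()
    for cell in board:
        if check_will_live(cell, board, max_x, max_y):
            next_board.add(cell)
        next_board.update(check_new_life(cell, board, max_x, max_y))
    return next_board


# A horizontal blinker on a 3x3 grid becomes a vertical one.
blinker = {(0, 1), (1, 1), (2, 1)}
next_generation = gridless_step(blinker, 3, 3)
```

Note that an empty cell next to several populated cells is examined once per populated neighbor in this version; the result is still correct because adding the same tuple to a set twice is harmless.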

Master: Only checking cell changes once per iteration

Despite being 3x faster than intermediate, gridless was doing some extra work. The code in the master branch attempts to eliminate it.

The most important change addresses the fact that gridless checked each empty cell once for every populated cell that neighbored it. Adding a checked_cells set, like the one in the intermediate branch, ensures that we check whether an empty cell should be populated in the next generation only one time:

        checked_cells = set()
        for cell in board:
            if check_will_live(cell, board, max_x, max_y):
                next_board.add(cell)
            checked_cells.add(cell)
            # Pass checked_cells into check_new_life so that
            # checking skips empty neighbors which have already
            # been checked this generation
            babies, barren = check_new_life(cell, board, checked_cells, max_x, max_y)
            checked_cells.update(babies)
            checked_cells.update(barren)
            # Newly populated cells still need to land on the next board
            next_board.update(babies)

The other, relatively small, optimization was to use Python’s builtin least-recently-used cache decorator on the find_neighbors function. This allowed us to skip recomputing the set of neighboring cells when the same cell’s neighbors were requested again soon afterward. In the set-based code, find_neighbors is called for the same cell back to back quite frequently, so this did have a noticeable impact.

+ @functools.lru_cache()
  def find_neighbors(cell, max_x, max_y):

These changes made the master branch an additional 30% faster than gridless, or nearly 30x as fast as the naive implementation that we started with.
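The effect of the cache can be observed directly through the cache_info() method that functools.lru_cache attaches to the decorated function. In the sketch below, the body of find_neighbors() is an assumption for illustration; it returns a frozenset because a cached return value is shared between callers and should not be mutated.

```python
import functools


@functools.lru_cache()
def find_neighbors(cell, max_x, max_y):
    """Return the in-bounds neighbors of cell; results are cached."""
    x, y = cell
    return frozenset(
        (nx, ny)
        for ny in (y - 1, y, y + 1)
        for nx in (x - 1, x, x + 1)
        if (nx, ny) != (x, y) and 0 <= nx < max_x and 0 <= ny < max_y
    )


first = find_neighbors((1, 1), 3, 3)    # computed: a cache miss
second = find_neighbors((1, 1), 3, 3)   # same arguments: served from the cache
info = find_neighbors.cache_info()
print(info.hits, info.misses)
```

The arguments must be hashable for the cache to work, which the (x, y) tuples and integer bounds already are.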

What have I learned?

  • I really made these optimizations to Conway’s Game of Life purely to see whether any of the ideas I had had about it in the past 25 years would make a difference. It is pleasing to see that they did.
  • I think that choice of programming language made a huge difference in how I approached this problem. Even allowing that I was a less experienced programmer 25 years ago, I think that the cost in programmer time to implement sets and other advanced data structures in C would likely have taken long enough that I wouldn’t have finished the fastest of these designs as a one week class assignment using C. Implementing the same ideas in Python, where the data structures exist as part of the language, took less than a day.
  • I had originally envisioned using a linked list or set instead of a 2d array as purely a memory optimization. After I started coding it, it quickly became clear that only saving the populated cells also reduced the number of cells that would have to be searched through. I’d like to say I saw that from the beginning but it was just serendipity.

Python, signal handlers, and exceptions

Never ever, ever raise a regular exception in a Python signal handler.

This is probably the best advice that I had never heard.  But after looking into an initial analysis of a timeout decorator bug it’s advice that I wish was prominently advertised.  So I’m publishing this little tale to let others know about this hidden gotcha so that, just maybe, when you have an opportunity to do this, a shiver will run down your spine, your memory will be tickled, and you’ll be able to avoid the pitfalls that I’ve recorded here.

When will I need to be reminded of this?

Signals are a means provided by an operating system to inform programs of  events that aren’t part of their normal program flow.  If you’re on a UNIX-like operating system and hit Control-C to cancel a program running in the shell, you’ve used a signal.  Control-C in the shell sends a SIGINT (Interrupt) signal to the program.  The program receives the SIGINT and a signal handler (a function which is written to take appropriate action when a signal is received) is invoked to handle it.  In most cases, for SIGINT, the signal handler tries to complete any pressing work (for instance, to finish writing data to a file so that the file isn’t left in a half-written state) and then causes the program to exit.

Python provides a signal library which is very similar to the C API underneath Python.  Python does a little bit of behind the scenes work to make sure that signals appear to be interacting with Python code rather than interfering with the interpreter’s implementation (meaning that signals will appear to happen between Python’s byte code instructions rather than leaving a Python instruction half-completed).

Your mission: exit when a signal occurs

If you search the internet for how to implement a timeout in Python, you’ll find tons of examples using signals, including one from the standard library documentation and one which is probably where the design for our original decorator came from.  So let’s create some quick test code to show how the signal works to implement a timeout:

import signal
import time

def handler(signum, frame):
    print('Signal handler called with signal',
          signum)
    raise OSError("timeout exceeded!")

def long_function():
    time.sleep(10)

# Set the signal handler and a 1-second alarm
old_handler = signal.signal(signal.SIGALRM,
                            handler)
signal.alarm(1)
# This sleeps for longer than the alarm
start = time.time()
try:
    long_function()
except OSError as e:
    duration = time.time() - start
    print('Duration: %.2f' % duration)
    raise
finally:
    signal.signal(signal.SIGALRM, old_handler)
    signal.alarm(0) # Disable the alarm

This code is adapted from the example of implementing a timeout in the standard library documentation.  We first define a signal handler named handler() which will raise an OSError when it’s invoked. We then define a function, long_function(), which is designed to take longer to run than our timeout. Then we hook everything together:

  • On lines 13-14 we register the handler to be invoked when a SIGALRM occurs.
  • On line 15 we tell Python to emit SIGALRM after 1 second.
  • On line 19 we call long_function()
  • On line 20-23, we specify what happens when a timeout occurs
  • In the finally block on lines 24-26, we undo our signal handler by first re-registering the prior SIGALRM handler as the function to invoke if a new SIGALRM occurs and then disabling the alarm that we had previously set.

When we run the code we see that the signal handler raises the OSError as expected:

$ /bin/time -p python3 test.py
Signal handler called with signal 14
Duration: 1.00
Traceback (most recent call last):
  File "test.py", line 18, in <module>
    long_function()
  File "test.py", line 11, in long_function
    time.sleep(2)
  File "test.py", line 8, in handler
    raise OSError("timeout exceeded!")
OSError: timeout exceeded!
real 1.04

Although long_function() takes 10 seconds to complete, the SIGALRM fires after 1 second.  That causes handler() to run which raises the OSError with the message timeout exceeded!.  The exception propagates to our toplevel where it is caught and prints Duration: 1.00 before re-raising so we can see the traceback.  We see that the output of /bin/time roughly agrees with the duration we calculated within the program… just a tad over 1 second to run the whole thing.

Exceptions seem to be working…?

It’s time to make our long_function code a little less trivial.

import signal
import time

def handler(signum, frame):
    print('Signal handler called with signal',
          signum)
    raise OSError("timeout exceeded!")

def long_function():
    try:
        with open('/etc/passwd', 'r') as f:
            data = f.read()
            # Simulate reading a lot of data
            time.sleep(10)
    except OSError:
        # retry once:
        with open('/etc/passwd', 'r') as f:
            data = f.read()
            time.sleep(10)

# Set the signal handler and a 1-second alarm
old_handler = signal.signal(signal.SIGALRM,
                            handler)
signal.alarm(1)
start = time.time()
# This sleeps for longer than the alarm
try:
    long_function()
except OSError as e:
    duration = time.time() - start
    print('Duration: %.2f' % duration)
    raise
finally:
    signal.signal(signal.SIGALRM, old_handler)
    signal.alarm(0)        # Disable the alarm

We’ve changed long_function(), lines 9-20, to do two new things:

  1. It now reads from a file.  We still have time.sleep(), but that’s just to simulate a long read operation which will trigger our timeout.
  2. open() and read() can raise a variety of exceptions.  If open() or read() raises an OSError exception, we want to retry reading the file once before we give up.

So what happens when we run this version?

$ /bin/time -p python3 test.py
Signal handler called with signal 14
real 11.07

That was unexpected…

As you can see from the output, our program still fired the SIGALRM.  And the signal handler still ran.  But after it ran, everything else seems to be different.  Apparently, the OSError didn’t propagate up the stack to our toplevel so we didn’t print the duration.  Furthermore, we ended up waiting for approximately 10 seconds in addition to the 1 second we waited for the timeout. What happened?

The key to understanding this is to realize that when the signal handler is invoked, Python adds the call to the signal handler onto its current call stack. (The call stack roughly represents the nesting of functions that lead up to where the code is currently executing.  So, in our example, inside of the signal handler, the call stack would start with the module toplevel, then long_function(), and finally handler().  If you look at the traceback from the first example, you can see exactly that call stack leading you through all the function calls to the point where the exception was raised.)  When the signal handler raises its exception, Python unwinds the call stack one function at a time to find an exception handler (not to be confused with our signal handler) which can handle the exception.

Where was the program waiting when SIGALRM was emitted?  It was on line 14, inside of long_function().  So Python acts as though handler() was invoked directly after that line.  And when handler() raises its OSError exception, Python then unwinds the call stack from handler() to long_function() and sees that on line 15, there’s an except OSError: and so Python lets it catch the signal handler’s OSError instead of propagating it up the stack further.  And in our code, that exception handler decides to retry reading the file which is where there is a second 10 second delay as we read the file.  Since the SIGALRM was already used up, the timeout doesn’t fire this time.  So the rest of the bug progresses from there: long_function() now waits the full 10 seconds before returning because there’s no timeout to stop it.  It then returns normally to its caller.  The caller doesn’t receive an OSError exception.  So it doesn’t fire its own OSError exception handler which would have printed the Duration.
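The same behavior can be reproduced in a fraction of a second with a condensed version of the example. This sketch swaps signal.alarm() for signal.setitimer(), which accepts fractional seconds, and wraps the setup in a hypothetical run_with_timeout() helper that reports who ended up handling the exception. It assumes a UNIX-like system and must run in the main thread.

```python
import signal
import time


def handler(signum, frame):
    raise OSError("timeout exceeded!")


def long_function():
    try:
        time.sleep(0.5)  # stands in for the slow read
    except OSError:
        # Meant for I/O errors, but it also catches the handler's OSError
        return "swallowed inside long_function"
    return "finished normally"


def run_with_timeout(seconds):
    """Hypothetical helper: arm a timer, call long_function, report the outcome."""
    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        return long_function()
    except OSError:
        return "caught by the caller"
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # disarm the timer
        signal.signal(signal.SIGALRM, old_handler)


result = run_with_timeout(0.1)
print(result)
```

The caller's except OSError never runs, because the exception is raised while long_function() is the innermost frame and its own handler gets the first chance to catch it.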

Okay, this isn’t good. Can it get worse?

There are even less intuitive ways that this bug can be provoked. For instance:

def long_function():
    time.sleep(0.8)
    try:
        time.sleep(0.1)
        time.sleep(0.1)
        time.sleep(0.1)
    except Exception:
        pass

In this version, we don’t have the context clue that we’re raising OSError in the signal handler and mistakenly catching it in long_function(). An experienced Python coder will realize that except Exception is sufficiently broad to be catching OSError, but in the clutter of other function calls and without the same named exception to provoke their suspicions initially, they might not realize that the occasional problems with the timeout not working could be triggered by this late exception handler.

This is also problematic if the programmer has to orient themselves from a traceback. It would lead them to look inside of long_function() for the source of an OSError. They won’t find it there because it is raised inside of the signal handler, which lives outside of the function.

import zipfile
def long_function():
    time.sleep(0.9)
    zipfile.is_zipfile('my_file.zip')

In this case, there’s no exception handling in our code to clue us in. If you look at the implementation of zipfile.is_zipfile(), though, you’ll see that there’s an except OSError: inside of there. If the timeout’s alarm ends up firing while you are inside of that code, is_zipfile() will just as happily swallow the OSError as an exception handler inside of long_function() would have.

So I think I can solve this if I….

There are ways to make the timeout functionality more robust. For instance, let’s define our handler like this:

import functools

class OurTimeout(BaseException):
    pass

def handler(timeout, signum, frame):
    print('Signal handler called with signal',
          signum)
    signal.alarm(timeout)
    raise OurTimeout("timeout exceeded!")

# toplevel code:
one_second_timeout = functools.partial(handler, 1)
old_handler = signal.signal(signal.SIGALRM, one_second_timeout)
signal.alarm(1)
try:
    long_function()
except OurTimeout:
    print('Timeout from our timeout function')

The first thing you can see is that this now raises a custom timeout exception which inherits from BaseException. This is more robust than using standard exceptions, which are rooted at Exception, because well written Python code won’t catch exceptions inheriting directly from BaseException. It’s not foolproof, however, because it’s only convention that prevents someone from catching BaseException. And no matter how good a Python programmer you are, I’m sure you’ve been lazy enough to write code like this at least once:

def long_function():
    try:
        with open('/etc/passwd', 'r') as f:
            f.read()
    except:
        # I'll fix the possible exception classes later
        rollback()

Bare except: catches any exception, including those derived directly from BaseException, so using a bare except will still break this timeout implementation.
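The difference between the two kinds of handler can be checked without any signals at all. In this small sketch the exception is raised directly rather than from a signal handler, and the function names and the outcome() helper are made up for illustration:

```python
class OurTimeout(BaseException):
    pass


def guarded_by_except_exception():
    try:
        raise OurTimeout("timeout exceeded!")
    except Exception:
        # Not reached: OurTimeout does not inherit from Exception
        return "swallowed"


def guarded_by_bare_except():
    try:
        raise OurTimeout("timeout exceeded!")
    except:  # deliberately too broad
        return "swallowed"


def outcome(func):
    """Report whether the exception escaped func or was caught inside it."""
    try:
        return func()
    except OurTimeout:
        return "propagated"


results = (outcome(guarded_by_except_exception), outcome(guarded_by_bare_except))
print(results)
```

The except Exception version lets OurTimeout escape to the caller, while the bare except version swallows it just as the OSError was swallowed earlier.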

The second thing this implementation changes is to reset the alarm inside of the signal handler. This is helpful as it means that the code will be able to time out multiple times to get back to the toplevel exception handler. In our retry example, both the initial attempt and the retry attempt would time out, so long_function() would end up taking only two seconds and fail due to our timeout in the end. However, there are still problems. For instance, this code ends up taking timeout * 3 seconds instead of just timeout seconds because the exception handler prevents us from hitting the break statement, so we have to keep hitting the timeout:

def long_function():
    for i in range(0, 3):
        try:
            time.sleep(10)
        except BaseException:  # broad enough to swallow OurTimeout
            pass
        else:
            break

The following code (arguably buggy because you *should* disable the alarm as the first thing in the recovery code instead of the last) can end up aborting the toplevel code’s recovery effort if timeout is too short:

try:
    long_function()
except OurTimeout:
    rollback_partial_long_function()
    signal.alarm(0)

So even though this code is better than before, it is still fragile, error prone, and can do the wrong thing even with code that isn’t obviously wrong.

Doctor, is there still hope?

When I was looking at the problem with the timeout decorator in Ansible, what struck me was that the decorator was applied outside of a function to tell it to time out, but the timeout was occurring, and potentially being processed, inside of the function. That meant that it would always be unintuitive and error prone for someone trying to use the decorator:

@timeout.timeout(1)
def long_function():
    try:
        time.sleep(10)
    except:
        # Try an alternative
        pass

try:
    long_function()
except timeout.TimeoutError:
    print('long_function did not complete')

When looking at the code, the assumption is that on the inside there’s long_function, then outside of it, the timeout code, and outside of that the caller. So the expectation is that an exception raised by the timeout code should only be processed by the caller since exceptions in Python only propagate up the call stack. Since the decorator’s functionality was implemented via a signal handler, though, that expectation was disappointed.

To solve this, I realized that the way signals and exceptions interact would never allow exceptions to propagate correctly. So I switched from using a signal to using one thread for the timeout and one thread for running the function. Simplified, that flow looks like this (you can look at the code for the new decorator in Ansible if you’re okay with the GPLv3+ license; the following code is all mine in case you want to re-implement the idea without the GPLv3+ restrictions):

import multiprocessing
import multiprocessing.pool
import time

def timeout(seconds=10):
    def decorator(func):
        def wrapper(*args, **kwargs):
            pool = multiprocessing.pool.ThreadPool(processes=1)
            results = pool.apply_async(func, args, kwargs)
            pool.close()
            try:
                return results.get(seconds)
            except multiprocessing.TimeoutError:
                raise OSError('Timeout expired after: %s' % seconds)
            finally:
                pool.terminate()
        return wrapper
    return decorator

@timeout(1)
def long_func():
    try:
        time.sleep(10)
    except OSError:
        print('Failure!')
    print('end of long_func')

try:
    long_func()
except OSError as e:
    print('Success!')
    raise

Edit: Davi Aguiar pointed out that the snippet was leaving the thread running after timeout. The example has been updated to add a pool.terminate() call inside of a finally: to take care of reaping the thread after the timeout expires.

As you can see, I create a multiprocessing.pool.ThreadPool with a single thread in it. Then I run the decorated function in that thread. I use results.get(seconds) with the timeout we set to get the results or raise an exception if the timeout is exceeded. If the timeout was exceeded, then I throw the OSError to be like our first example.

If all goes well, the exception handling inside of long_func won’t have any effect on what happens. Let’s see:

$ python2 test.py
Success!
Traceback (most recent call last):
  File "test.py", line 49, in <module>
    long_func()
  File "test.py", line 35, in wrapper
    raise OSError('Timeout expired after: %s' % seconds)
OSError: Timeout expired after: 1

Yep, as you can see from the stack trace, the OSError is now being thrown from the decorator, not from within long_func(). So only the toplevel exception handler has a chance to handle the exception. This leads to more intuitive code and hopefully fewer bugs down the road.
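The same idea can also be sketched with concurrent.futures from the standard library. This is a swapped-in variation for illustration, not the decorator that went into Ansible, and the CallTimedOut exception is a made-up name. The delays are shortened so that the example runs quickly.

```python
import concurrent.futures
import functools
import time


class CallTimedOut(Exception):
    """Raised in the caller when the wrapped function runs too long."""


def timeout(seconds):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(func, *args, **kwargs)
            try:
                return future.result(timeout=seconds)
            except concurrent.futures.TimeoutError:
                raise CallTimedOut('Timeout expired after: %s' % seconds)
            finally:
                # Do not block on the worker; it keeps running in the
                # background until func returns on its own.
                executor.shutdown(wait=False)
        return wrapper
    return decorator


@timeout(0.1)
def long_func():
    try:
        time.sleep(0.5)
    except Exception:
        return 'swallowed inside long_func'
    return 'finished'


try:
    outcome = long_func()
except CallTimedOut:
    outcome = 'caller saw the timeout'
print(outcome)
```

Because the exception is raised in the calling thread, the except Exception inside long_func() never sees it. As with the ThreadPool version, the worker thread is not forcibly stopped; Python offers no way to kill a thread, so the function keeps running until it finishes.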

So are there ever times when a signal handler is the right choice?

Signal handlers can be used for more than just timeouts. And they can do other things besides trying to cancel an ongoing function. For instance, the Apache web server has a signal handler for the HUP (HangUP) signal. When it receives that, it reloads its configuration from disk. If you take care to catch any potential exceptions during that process, you shouldn’t run across these caveats because these problems only apply to raising an exception from the signal handler.
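A common way to write that kind of handler is to have it do nothing except record that the signal arrived, and let the main loop act on the flag at a safe point. The sketch below illustrates the pattern on a UNIX-like system; it is not how Apache implements its reload, and load_config() and maybe_reload() are hypothetical stand-ins.

```python
import os
import signal

reload_requested = False


def hup_handler(signum, frame):
    """Only record that a reload was requested; never raise from here."""
    global reload_requested
    reload_requested = True


def load_config():
    """Hypothetical stand-in for re-reading configuration from disk."""
    return {'loaded': True}


def maybe_reload(config):
    """Called from the main loop at a point where reloading is safe."""
    global reload_requested
    if not reload_requested:
        return config
    reload_requested = False
    try:
        return load_config()
    except OSError:
        # Keep running with the old configuration if the reload fails
        return config


old_handler = signal.signal(signal.SIGHUP, hup_handler)
try:
    os.kill(os.getpid(), signal.SIGHUP)  # simulate `kill -HUP <pid>`
    config = maybe_reload({'loaded': False})
finally:
    signal.signal(signal.SIGHUP, old_handler)
print(config)
```

Since the handler never raises, nothing can be swallowed by an unrelated except block, and any exception from the reload itself is raised and handled in ordinary program flow.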

When you do want to exit the function via a signal handler, I would be a little hesitant because you can never entirely escape from the drawbacks above and threading provides a more intuitive interface for the called and calling code.  I think that a few practices make it more robust, however.  As mentioned above:

  • Use a custom exception inheriting directly from BaseException to reduce the likelihood that unexpected code will catch it.
  • If you are doing it for a timeout, remember to reset the alarm inside of the signal handler in case something does catch your custom exception.  At least then you’ll continue to hit timeouts until you finally escape the code.

And one that wasn’t mentioned before:

It’s better to make an ad hoc exception-raising signal handler that handles a specific problem inside of a function rather than attempting to place it outside of the function.  For instance:

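import signal
import time

# Signal handlers are always called with the signal number and the
# current stack frame
def timeout_handler(signum, frame):
    raise BaseException('Timeout')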

def long_function():
    try:
        old_handler = signal.signal(signal.SIGALRM,
                                    timeout_handler)
        signal.alarm(1)
        time.sleep(10)   # Placeholder for the real,
                         # potentially blocking code
    except BaseException:
        print('Timed out!')
    finally:
        signal.signal(signal.SIGALRM, old_handler)
        signal.alarm(0)

long_function()

Why is it better to do it this way? If you do this, you localize the exception catching due to the signal handler with the code that could potentially raise the exception. It isn’t foolproof as you are likely still going to call helper functions (even in the stdlib) which could handle a conflicting exception but at least you are more clearly aware of the fact that the exception could potentially originate from any of your code inside this function rather than obfuscating all of that behind a function call.

Voluptuous and Python-3.4 Enums

Last year I played around with using jsonschema for validating some data that I was reading into my Python programs.  The API for the library was straightforward but the schema turned out to be a pretty sprawling affair: stellar-schema.json

This past weekend I started looking at adding some persistence to this program in the form of SQLAlchemy-backed data structures.  To make those work right, I decided that I had to change the schema for these data files as well.  Looking at the jsonschema, I realized that I was having a hard time remembering what all the schema keywords meant and how to modify the file; never a good sign!  So I decided to reimplement the schema in something simpler.

At work, we use a library called voluptuous.  One of the nice things about it is that the containers are Python data structures and the types are Python type constructors.  So simple things can be very simple:

from voluptuous import Schema

schema = Schema([
  {'hosts': str,
   'gather_facts': bool,
   'tasks': [],
   }
])

The unfortunate thing about voluptuous is that the reference documentation is very bad.  It doesn’t have any recipe-style documentation which can teach you how to best accomplish tasks.  So when you do have to reach deeper to figure out how to do something a bit more complex, you can try to find an example in the README.md (a pretty good resource but there’s no table of contents so you’re often left wondering whether what you want is documented there or not) or you might have to hunt around with google for a blog post or stackoverflow question that explains how to accomplish that.  If someone else hasn’t discovered an answer already…. well, then, voluptuous may well have just the feature you need but you might never find it.

The feature that I needed today was to integrate Python-3.4’s Enum type with voluptuous.  The closest I found was this feature request which asked for enum support to be added.  The issue was closed with a few examples of “code that works” which didn’t quite explain all the underlying concepts to a new voluptuous user like myself.  I took away three ideas:

  • Schema(Coerce(EnumType)) was supposed to work
  • Schema(EnumType) was supposed to work but…
  • Schema(EnumType) doesn’t work in a way that the person who opened the issue (or I) would get value from.

I was still confused but at least now I had some pieces of code to try out.  So here’s what my first series of tests looked like:

from enum import Enum 
import voluptuous as v

# My eventual goal: data_to_validate = {"type": "one"}

MyTypes = Enum("MyTypes", ['one', 'two', 'three'])
s1 = v.Schema(MyTypes)
s2 = v.Schema(v.Coerce(MyTypes))

s1('one')  # validation error (figured that would be the case from the ticket)
s1(MyTypes.one)  #  Works but not helpful for me
s2('one')  # validation error (Thought this one would work...)

Hmm… so far, this isn’t looking too hopeful. The only thing that I got working isn’t going to help me validate my actual data. Well, let’s google for what Coerce actually does….

A short while later, I think I see what’s going on. Coerce will attempt to convert the data given to it into a new type via the function it is given. In the case of an Enum, you can call the Enum on the actual values backing the Enum, not the symbolic labels. For the symbolic labels, you need to use square bracket notation. Square brackets are syntax for the __getitem__ magic method so maybe we can pass that in to get what we want:

MyTypes = Enum("MyTypes", ['one', 'two', 'three'])

s2 = v.Schema(v.Coerce(MyTypes))
s3 = v.Schema(v.Coerce(MyTypes.__getitem__))

s2(1)  # This validates
s3('one')  # And so does this!  Yay!

Okay, so now we think that this all makes sense…. but there’s actually one more wrinkle that we have to work out. It turns out that Coerce only marks a value as Invalid if the function it’s given throws a TypeError or ValueError. __getitem__ throws a KeyError so guess what:

>>> symbolic_only({"type": "five"})
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/badger/.local/lib/python3.6/site-packages/voluptuous/schema_builder.py", line 267, in __call__
    return self._compiled([], data)
  File "/home/badger/.local/lib/python3.6/site-packages/voluptuous/schema_builder.py", line 587, in validate_dict
    return base_validate(path, iteritems(data), out)
  File "/home/badger/.local/lib/python3.6/site-packages/voluptuous/schema_builder.py", line 379, in validate_mapping
    cval = cvalue(key_path, value)
  File "/home/badger/.local/lib/python3.6/site-packages/voluptuous/schema_builder.py", line 769, in validate_callable
    return schema(data)
  File "/home/badger/.local/lib/python3.6/site-packages/voluptuous/validators.py", line 95, in __call__
    return self.type(v)
  File "/usr/lib64/python3.6/enum.py", line 327, in __getitem__
    return cls._member_map_[name]
KeyError: 'five'
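
The mismatch in exception types can be reproduced with the standard library alone. Here is a small sketch (exception_name is a helper invented for this illustration, not part of enum or voluptuous) showing that calling the Enum with a bad value raises ValueError while the square bracket lookup with a bad name raises KeyError:

```python
from enum import Enum

MyTypes = Enum("MyTypes", ["one", "two", "three"])

def exception_name(func, arg):
    """Call func(arg) and return the name of the exception it raises, if any."""
    try:
        func(arg)
    except Exception as exc:
        return type(exc).__name__
    return None

# Both lookups find the same member when the input is valid.
print(MyTypes(1) is MyTypes["one"])                 # True

# They fail with different exception types when it is not.
print(exception_name(MyTypes, 5))                   # ValueError
print(exception_name(MyTypes.__getitem__, "five"))  # KeyError
```

Only the first of those two failures is one that Coerce turns into a validation error for us.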

So the schema call throws a KeyError instead of a voluptuous Invalid exception. Okay, no problem, we just have to remember to wrap the __getitem__ with a function which raises ValueError if the name isn’t present. Anything else you should be aware of? Well, for my purposes, I only want the enum’s symbolic names to match but what if you wanted either the symbolic names or the actual values to work (s2 or s3)? For that, you can combine this with voluptuous’s Any function. Here’s what those validators will look like:

from enum import Enum 
import voluptuous as v

data_to_validate = {"type": "one"}
MyTypes = Enum("MyTypes", ['one', 'two', 'three'])

def mytypes_validator(value):
    try:
        MyTypes[value]
    except KeyError:
        raise ValueError(f"{value} is not a valid member of MyTypes")
    return value

symbolic_only = v.Schema({"type": v.Coerce(mytypes_validator)})
symbols_and_values = v.Schema({"type":
                       v.Any(v.Coerce(MyTypes),
                             v.Coerce(mytypes_validator),
                     )})

symbolic_only(data_to_validate)   # Hip Hip!
symbols_and_values(data_to_validate)  # Hooray!

symbols_and_values({"type": 1})  # If this is what you *really* want
symbolic_only({"type": 1})  # Raises MultipleInvalid; use this schema if you want implementation to remain hidden

symbols_and_values({"type": 5})  # Raises MultipleInvalid; proves that invalids are getting caught

Better unittest coverage stats: excluding code that only executes on Python 2/3

What we’re testing

In my last post I used pytest to write a unittest for code which wrote to sys.stdout on Python2 and sys.stdout.buffer on Python3. Here’s that code again:

# byte_writer.py
import sys
import six

def write_bytes(some_bytes):
    # some_bytes must be a byte string
    if six.PY3:
        stdout = sys.stdout.buffer
    else:
        stdout = sys.stdout
    stdout.write(some_bytes)

and the unittest to test it:

# test_byte_writer.py
import io
import sys

from byte_writer import write_bytes

def test_write_byte(capfdbinary):
    write_bytes(b'\xff')
    out, err = capfdbinary.readouterr()
    assert out == b'\xff'

This all works great but as we all know, there’s always room for improvement.

Coverage

When working on code that’s intended to run on both Python2 and Python3, you often hear the advice that making sure that every line of your code is executed during automated testing is extremely helpful. This is, in part, because several classes of error that are easy to introduce in such code are easily picked up by tests which simply exercise the code. For instance:

  • Functions which were renamed or moved will throw an error when code attempts to use them.
  • Mixing of byte and text strings will throw an error as soon as the code that combines the strings runs on Python3.

These benefits are over and above the behaviours that your tests were specifically written to check for.

Once you accept that you want to strive for 100% coverage, though, the next question is how to get there. The Python coverage library can help. It integrates with your unittest framework to track which lines of code and which branches have been executed over the course of running your unittests. If you’re using pytest to write your unittests like I am, you will want to install the pytest-cov package to integrate coverage with pytest.

If you recall my previous post, I had to use a recent checkout of pytest to get my unittests working so we have to make sure not to overwrite that version when we install pytest-cov. It’s easiest to create a test-requirements.txt file to make sure that we get the right versions together:

$ cat test-requirements.txt
pytest-cov
git+git://github.com/pytest-dev/pytest.git@6161bcff6e3f07359c94a7be52ad32ecb8822142
$ pip install --user --upgrade -r test-requirements.txt
$ pip3 install --user --upgrade -r test-requirements.txt

And then we can run pytest to get coverage information:

$ pytest --cov=byte_writer  --cov-report=term-missing --cov-branch
======================= test session starts ========================
test_byte_writer.py .

---------- coverage: platform linux, python 3.5.4-final-0 ----------
Name             Stmts   Miss Branch BrPart  Cover   Missing
------------------------------------------------------------
byte_writer.py       7      1      2      1    78%   10, 7->10
===================== 1 passed in 0.02 seconds =====================

Yay! The output shows that currently 78% of byte_writer.py is executed when the unittests are run. The Missing column shows two types of missing things. The 10 means that line 10 is not executed. The 7->10 means that the jump from line 7 to line 10 was never taken, so only one side of the branch on line 7 was exercised. Since the two missing pieces coincide, we probably only have to add a single test case that satisfies the second code branch path to reach 100% coverage. Right?

Conditions that depend on Python version

If we take a look at lines 7 through 10 in byte_writer.py we can see that this might not be as easy as we first assumed:

    if six.PY3:
        stdout = sys.stdout.buffer
    else:
        stdout = sys.stdout

Which code path executes here is dependent on the Python version the code runs under. Line 8 is executed when we run on Python3 and Line 10 is skipped. When run on Python2 the opposite happens. Line 10 is executed and Line 8 is skipped. We could mock out the value in six.PY3 to hit both code paths but since sys.stdout only has a buffer attribute on Python3, we’d have to mock that out too. And that would lead us back into conflict with pytest capturing stdout as we figured out in my previous post.

Taking a step back, it also doesn’t make sense that we’d test both code paths on the same run. Each code path should be executed when the environment is correct for it to run and we’d be better served by modifying the environment to trigger the correct code path. That way we also test that the code is detecting the environment correctly. For instance, if the conditional was triggered by the user’s encoding, we’d probably run the tests under different locale settings to check that each encoding went down the correct code path. In the case of a Python version check, we modify the environment by running the test suite under a different version of Python. So what we really want is to make sure that the correct branch is run when we run the test suite on Python3 and the other branch is run when we execute it under Python2. Since we already have to run the test suite under both Python2 and Python3 this has the net effect of testing all of the code.

So how do we achieve that?

Excluding lines from coverage

Coverage has the ability to match lines in your code against a regular expression and then exclude those lines from the coverage report. This allows you to tell coverage that a branch of code will never be executed and thus it shouldn’t contribute to the list of unexecuted code. By default, the only pattern to match is “pragma: no cover”. However, “pragma: no cover” will unconditionally exclude the lines it matches which isn’t really what we want. We do want to test for coverage of the branch but only when we’re running on the correct Python version. Luckily, the matched lines are customizable and further, they can include environment variables. The combination of these two abilities means that we can add lines to exclude that are different when we run on Python2 and Python3. Here’s how to configure coverage to do what we need it to:

$ cat .coveragerc
[report]
exclude_lines=
    pragma: no cover
    pragma: no py${PYTEST_PYMAJVER} cover

This .coveragerc includes the standard matched line, pragma: no cover and our addition, pragma: no py${PYTEST_PYMAJVER} cover. The addition uses an environment variable, PYTEST_PYMAJVER so that we can vary the string that’s matched when we invoke pytest.

Next we need to change the code in byte_writer.py so that the special strings are present:

    if six.PY3:  # pragma: no py2 cover
        stdout = sys.stdout.buffer
    else:  # pragma: no py3 cover
        stdout = sys.stdout

As you can see, we’ve added the special comments to the two conditional lines. Let’s run this manually and see if it works now:

$ PYTEST_PYMAJVER=3 pytest --cov=byte_writer  --cov-report=term-missing --cov-branch
======================= test session starts ========================
test_byte_writer.py .
---------- coverage: platform linux, python 3.5.4-final-0 ----------
Name             Stmts   Miss Branch BrPart  Cover   Missing
------------------------------------------------------------
byte_writer.py       6      0      0      0   100%
==================== 1 passed in 0.02 seconds =====================

$ PYTEST_PYMAJVER=2 pytest-2 --cov=byte_writer  --cov-report=term-missing --cov-branch
======================= test session starts ========================
test_byte_writer.py .
--------- coverage: platform linux2, python 2.7.13-final-0 ---------
Name             Stmts   Miss Branch BrPart  Cover   Missing
------------------------------------------------------------
byte_writer.py       5      0      0      0   100%
===================== 1 passed in 0.02 seconds =====================

Well, that seemed to work! Let’s run it once more with the wrong PYTEST_PYMAJVER value to show that coverage is still recording information on the branches that are supposed to be used:

$ PYTEST_PYMAJVER=3 pytest-2 --cov=byte_writer  --cov-report=term-missing --cov-branch
======================= test session starts ========================
test_byte_writer.py .
-------- coverage: platform linux2, python 2.7.13-final-0 --------
Name             Stmts   Miss Branch BrPart  Cover   Missing
------------------------------------------------------------
byte_writer.py       6      1      0      0    83%   8
===================== 1 passed in 0.02 seconds =====================

Yep. When we specify the wrong PYTEST_PYMAJVER value, the coverage report shows that the missing line is included as an unexecuted line. So that seems to be working.
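
If it helps to see why the value of PYTEST_PYMAJVER matters, here is a rough simulation of the matching step. This is only a sketch of the idea, not coverage's actual implementation: RAW_PATTERN, SOURCE, and excluded_lines are names made up for this illustration, and string.Template stands in for the environment variable expansion that coverage does when it reads .coveragerc.

```python
import re
from string import Template

# The custom exclusion entry from .coveragerc, before expansion.
RAW_PATTERN = "pragma: no py${PYTEST_PYMAJVER} cover"

# The annotated conditional from byte_writer.py.
SOURCE = """\
if six.PY3:  # pragma: no py2 cover
    stdout = sys.stdout.buffer
else:  # pragma: no py3 cover
    stdout = sys.stdout
"""

def excluded_lines(major_version):
    """Return the source lines that the expanded pattern matches."""
    pattern = Template(RAW_PATTERN).substitute(PYTEST_PYMAJVER=major_version)
    return [line for line in SOURCE.splitlines() if re.search(pattern, line)]

print(excluded_lines("3"))  # only the else line matches
print(excluded_lines("2"))  # only the if line matches
```

When a matched line introduces a block, coverage excludes the whole block. That is why setting PYTEST_PYMAJVER=3 under Python2 removes the else clause that Python2 does run from the report and leaves line 8, which Python2 never runs, counted as missing.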

Setting PYTEST_PYMAJVER automatically

Just one more thing… it’s kind of a pain to have to set the PYTEST_PYMAJVER variable with every test run, isn’t it? Wouldn’t it be better if pytest would automatically set that for you? After all, pytest knows which Python version it’s running under so it should be able to. I thought so too so I wrote pytest-env-info to do just that. When installed, pytest-env-info will set PYTEST_VER, PYTEST_PYVER, and PYTEST_PYMAJVER so that they are available to pytest-cov and other, similar plugins which can use environment variables to configure themselves. It’s available on pypi so all you have to do to enable it is add it to your requirements so that pip will install it and then run pytest:

$ cat test-requirements.txt
pytest-cov
git+git://github.com/pytest-dev/pytest.git@6161bcff6e3f07359c94a7be52ad32ecb8822142
pytest-env-info
$ pip3 install --user -r test-requirements.txt
$ pytest --cov=byte_writer  --cov-report=term-missing --cov-branch
======================== test session starts =========================
plugins: env-info-0.1.0, cov-2.5.1
test_byte_writer.py .
---------- coverage: platform linux2, python 2.7.14-final-0 ----------
Name             Stmts   Miss Branch BrPart  Cover   Missing
------------------------------------------------------------
byte_writer.py       5      0      0      0   100%
====================== 1 passed in 0.02 seconds ======================
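
The core of the idea is small. The following is a minimal sketch of it, not the plugin's actual source: export_python_major_version is a name invented here, and the sketch says nothing about where in pytest's startup such code has to run for pytest-cov to see the variable.

```python
import os
import sys

def export_python_major_version(environ=os.environ):
    """Store the running interpreter's major version in PYTEST_PYMAJVER."""
    environ["PYTEST_PYMAJVER"] = str(sys.version_info[0])
    return environ["PYTEST_PYMAJVER"]

# With the variable set, "pragma: no py${PYTEST_PYMAJVER} cover" expands to
# the pragma for whichever interpreter is running the tests.
print(export_python_major_version())
```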

Python testing: Asserting raw byte output with pytest

The Code to Test

When writing code that can run on both Python2 and Python3, I’ve sometimes found that I need to write raw bytes to stdout. Here’s some typical code I might write to do that:

# byte_writer.py
import sys
import six

def write_bytes(some_bytes):
    # some_bytes must be a byte string
    if six.PY3:
        stdout = sys.stdout.buffer
    else:
        stdout = sys.stdout
    stdout.write(some_bytes)

if __name__ == '__main__':
    write_bytes(b'\xff')

In this example, my code needs to write a raw byte to stdout. To do this, it uses sys.stdout.buffer on Python3 to circumvent the automatic encoding/decoding that occurs on Python3’s sys.stdout. So far so good. Python2 expects bytes to be written to sys.stdout by default so we can write the byte string directly to sys.stdout in that case.

The First Attempt: Pytest newb, but willing to learn!

Recently I wanted to write a unittest for some code like that. I had never done this in pytest before so my first try looked a lot like my experience with nose or unittest2: override sys.stdout with an io.BytesIO object and then assert that the right values showed up in sys.stdout:

# test_byte_writer.py
import io
import sys

import mock
import pytest
import six

from byte_writer import write_bytes

@pytest.fixture
def stdout():
    real_stdout = sys.stdout
    fake_stdout = io.BytesIO()
    if six.PY3:
        sys.stdout = mock.MagicMock()
        sys.stdout.buffer = fake_stdout
    else:
        sys.stdout = fake_stdout

    yield fake_stdout

    sys.stdout = real_stdout

def test_write_byte(stdout):
    write_bytes(b'a')
    assert stdout.getvalue() == b'a'

This gave me an error:

[pts/38@roan /var/tmp/py3_coverage]$ pytest              (07:46:36)
_________________________ test_write_byte __________________________

stdout = 

    def test_write_byte(stdout):
        write_bytes(b'a')
>       assert stdout.getvalue() == b'a'
E       AssertionError: assert b'' == b'a'
E         Right contains more items, first extra item: 97
E         Use -v to get the full diff

test_byte_writer.py:27: AssertionError
----------------------- Captured stdout call -----------------------
a
===================== 1 failed in 0.03 seconds =====================

I could plainly see from pytest’s “Captured stdout” output that my test value had been printed to stdout. So it appeared that my stdout fixture just wasn’t capturing what was printed there. What could be the problem? Hmmm…. Captured stdout… If pytest is capturing stdout, then perhaps my fixture is getting overridden by pytest’s internal facility. Let’s google and see if there’s a solution.

The Second Attempt: Hey, that’s really neat!

Wow, not only did I find that there is a way to capture stdout with pytest, I found that you don’t have to write your own fixture to do so. You can just use pytest’s builtin capfd fixture instead. Cool, that should be much simpler:

# test_byte_writer.py
import io
import sys

from byte_writer import write_bytes

def test_write_byte(capfd):
    write_bytes(b'a')
    out, err = capfd.readouterr()
    assert out == b'a'

Okay, that works fine on Python2 but on Python3 it gives:

[pts/38@roan /var/tmp/py3_coverage]$ pytest              (07:46:41)
_________________________ test_write_byte __________________________

capfd = 

    def test_write_byte(capfd):
        write_bytes(b'a')
        out, err = capfd.readouterr()
>       assert out == b'a'
E       AssertionError: assert 'a' == b'a'

test_byte_writer.py:10: AssertionError
===================== 1 failed in 0.02 seconds =====================

The assert looks innocuous enough. So if I was an insufficiently paranoid person I might be tempted to think that this was just stdout using python native string types (bytes on Python2 and text on Python3) so the solution would be to use a native string here ("a" instead of b"a"). However, where the correctness of anyone else’s bytes <=> text string code is concerned, I subscribe to the philosophy that you can never be too paranoid. So….

The Third Attempt: I bet I can break this more!

Rather than make the seemingly easy fix of switching the test expectation from b"a" to "a" I decided that I should test whether some harder test data would break either pytest or my code. Now my code is intended to push bytes out to stdout even if those bytes are non-decodable in the user’s selected encoding. On modern UNIX systems this is usually controlled by the user’s locale. And most of the time, the locale setting specifies a UTF-8 compatible encoding. With that in mind, what happens when I pass a byte string that is not legal in UTF-8 to write_bytes() in my test function?

# test_byte_writer.py
import io
import sys

from byte_writer import write_bytes

def test_write_byte(capfd):
    write_bytes(b'\xff')
    out, err = capfd.readouterr()
    assert out == b'\xff'

Here I adapted the test function to attempt writing the byte 0xff (255) to stdout. In UTF-8, this is an illegal byte (ie: by itself, that byte cannot be mapped to any unicode code point) which makes it good for testing this. (If you want to make a truly robust unittest, you should probably standardize on the locale settings (and hence, the encoding) to use when running the tests. However, that deserves a blog post of its own.) Anyone want to guess what happens when I run this test?

[pts/38@roan /var/tmp/py3_coverage]$ pytest              (08:19:52)
_________________________ test_write_byte __________________________

capfd = 

    def test_write_byte(capfd):
        write_bytes(b'\xff')
        out, err = capfd.readouterr()
>       assert out == b'\xff'
E       AssertionError: assert '�' == b'\xff'

test_byte_writer.py:10: AssertionError
===================== 1 failed in 0.02 seconds =====================

On Python3, we see that the undecodable byte is replaced with the unicode replacement character. Pytest is likely running the equivalent of b"Byte string".decode(errors="replace") on stdout. This is good when capfd is used to display the Captured stdout call information to the console. Unfortunately, it is not what we need when we want to check that our exact byte string was emitted to stdout.
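
The effect of that kind of decoding is easy to reproduce outside of pytest. This short sketch (the variable names are just for illustration) shows that a strict UTF-8 decode rejects the byte outright, while errors="replace" quietly swaps in U+FFFD and loses the original byte:

```python
raw = b"\xff"

# Strict decoding refuses the byte because 0xff never appears in valid UTF-8.
try:
    raw.decode("utf-8")
    strict_result = "decoded"
except UnicodeDecodeError:
    strict_result = "UnicodeDecodeError"

# errors="replace" substitutes the replacement character U+FFFD instead.
replaced = raw.decode("utf-8", errors="replace")

print(strict_result)                    # UnicodeDecodeError
print(replaced == u"\ufffd")            # True
print(replaced.encode("utf-8") == raw)  # False: the round trip is lossy
```

Because the round trip is lossy, no assertion against the decoded text can prove which bytes were really written.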

With this change, it also becomes apparent that the test isn’t doing the right thing on Python2 either:

[pts/38@roan /var/tmp/py3_coverage]$ pytest-2            (08:59:37)
_________________________ test_write_byte __________________________

capfd = 

    def test_write_byte(capfd):
        write_bytes(b'\xff')
        out, err = capfd.readouterr()
>       assert out == b'\xff'
E       AssertionError: assert '�' == '\xff'
E         - �
E         + \xff

test_byte_writer.py:10: AssertionError
========================= warnings summary =========================
test_byte_writer.py::test_write_byte
  /var/tmp/py3_coverage/test_byte_writer.py:10: UnicodeWarning: Unicode equal comparison failed to convert both arguments to Unicode - interpreting them as being unequal
    assert out == b'\xff'

-- Docs: http://doc.pytest.org/en/latest/warnings.html
=============== 1 failed, 1 warnings in 0.02 seconds ===============

In the previous version, this test passed. Now we see that the test was passing because Python2 evaluates u"a" == b"a" as True. However, that’s not really what we want to test; we want to test that the byte string we passed to write_bytes() is the actual byte string that was emitted on stdout. The new data shows that instead, the test is converting the value that got to stdout into a text string and then trying to compare that. So a fix is needed on both Python2 and Python3.

These problems are down in the guts of pytest. How are we going to fix them? Will we have to seek out a different strategy that lets us capture stdout, overriding pytest’s builtin?

The Fourth Attempt: Fortuitous Timing!

Well, as it turns out, the pytest maintainers merged a pull request four days ago which implements a capfdbinary fixture. capfdbinary is like the capfd fixture that I was using in the above example but returns data as byte strings instead of as text strings. Let’s install it and see what happens:

$ pip install --user git+git://github.com/pytest-dev/pytest.git@6161bcff6e3f07359c94a7be52ad32ecb8822142
$ mv ~/.local/bin/pytest ~/.local/bin/pytest-2
$ pip3 install --user git+git://github.com/pytest-dev/pytest.git@6161bcff6e3f07359c94a7be52ad32ecb8822142

And then update the test to use capfdbinary instead of capfd:

# test_byte_writer.py
import io
import sys

from byte_writer import write_bytes

def test_write_byte(capfdbinary):
    write_bytes(b'\xff')
    out, err = capfdbinary.readouterr()
    assert out == b'\xff'

And with those changes, the tests now pass:

[pts/38@roan /var/tmp/py3_coverage]$ pytest              (11:42:06)
======================= test session starts ========================
platform linux -- Python 3.5.4, pytest-3.2.5.dev194+ng6161bcf, py-1.4.34, pluggy-0.5.2
rootdir: /var/tmp/py3_coverage, inifile:
plugins: xdist-1.15.0, mock-1.5.0, cov-2.4.0, asyncio-0.5.0
collected 1 item                                                    

test_byte_writer.py .

===================== 1 passed in 0.01 seconds =====================

Yay! Mission accomplished.

Python2, string .format(), and unicode

Primer

If you’ve dealt with unicode and byte str mixing in python2 before, you’ll know that there are certain percent-formatting operations that you absolutely should not do with them. For instance, if you are combining a string of each type and they both have non-ascii characters then you are going to get a traceback:

>>> print(u'くら%s' % (b'とみ',))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'ascii' codec can't decode byte 0xe3 in position 0: ordinal not in range(128)
>>> print(b'くら%s' % (u'とみ',))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'ascii' codec can't decode byte 0xe3 in position 0: ordinal not in range(128)

The canonical answer to this is to clean up your code to not mix unicode and byte str which seems fair enough here. You can convert one of the two strings to match with the other fairly easily:

>>> print(u'くら%s' % (unicode(b'とみ', 'utf-8'),))
くらとみ

However, if you’re part of a project which was written before the need to separate the two string types was realized you may be mixing the two types sometimes and relying on bug reports and python tracebacks to alert you to pieces of the code that need to be fixed. If you don’t get tracebacks then you may not bother to explicitly convert in some cases. Unfortunately, as code is changed you may find that the areas you thought of as safe to mix aren’t quite as broad as they first appeared. That can lead to UnicodeError exceptions suddenly popping up in your code with seemingly harmless changes….

A New Idiom

If you’re like me and trying to adopt python3-supported idioms into your python-2.6+ code bases then one of the changes you may be making is to switch from using percent formatting to construct your strings to the new string .format() method. This is usually fairly straightforward:

name = u"Kuratomi"

# Old style
print("Hello Mr. %s!" % (name,))

# New style
print("Hello Mr. {0}!".format(name))

# Output:
Hello Mr. Kuratomi!
Hello Mr. Kuratomi!

This seems like an obvious transformation with no possibility of UnicodeError being thrown. And for this simple example you’d be right. But we all know that real code is a little more obfuscated than that. So let’s start making this a little more real-world, shall we?

name = u"くらとみ"
print("Hello Mr. %s!" % (name,))
print("Hello Mr. {0}!".format(name))

# Output
Hello Mr. くらとみ!
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
UnicodeEncodeError: 'ascii' codec can't encode characters in position 0-3: ordinal not in range(128)

What happened here? In our code we set name to a unicode string that has non-ascii characters. Used with the old-style percent formatting, this continued to work fine. But with the new-style .format() method we ended up with a UnicodeError. Why? Well under the hood, the percent formatting uses the “%” operator. The function that handles the “%” operator (__mod__()) sees that it was given two strings, one of which is a byte str and one of which is a unicode string. It then decides to convert the byte str to a unicode string and combine the two. Since our example only has ascii characters in the byte string, it converts successfully and python can then construct the unicode string u"Hello Mr. くらとみ!". Since it’s always the byte str that’s converted to unicode type we can build up an idea of what things will work and which will throw an exception:

# These are good as the byte string
# which is converted is ascii-only
"Mr. %s" % (u"くらとみ",)
u"%s くらとみ" % ("Mr.",)

# Output of either of those:
u"Mr. くらとみ"

# These will throw an exception as the
# *byte string* contains non-ascii characters
u"Mr. %s" % ("くらとみ",)
"%s くらとみ" % (u"Mr",)

Okay, so that explains what’s happening with the percent-formatting example. What’s happening with the .format() code? .format() is a method of one of the two string types (str for python2 byte strings or unicode for python2 text strings). This gives programmers a feeling that the method is more closely associated with the type it is a method of than the parameters that it is given. So the design decision was made that the method should convert to the type that the method is bound to instead of always converting to unicode string type. This means that we have to make sure parameters can be converted to the type of the format string rather than always to unicode. With that in mind, this is the matrix of things we expect to work and expect to fail:

# These are good as the parameter string
# which is converted is ascii-only
u"{0} くらとみ".format("Mr.")
"{0} くらとみ".format(u"Mr.")

# Output (first is a unicode, second is a str):
u"Mr. くらとみ"
"Mr. くらとみ"

# These will throw an exception as the
# parameters contain non-ascii characters
u"Mr. {0}".format("くらとみ")
"Mr. {0}".format(u"くらとみ")

So now we know why we get a traceback in the converted code but not in the original code. Let’s apply this to our example:

name = u"くらとみ"
# name is a unicode type so we need to make
# sure .format() does not implicitly convert it
print(u"Hello Mr. {0}!".format(name))

# Output
Hello Mr. くらとみ!

Alright! That seems good now, right? Are we done? Well, let’s take this real-world thing one step farther. With real-world users we often get transient errors because users are entering a value we didn’t test with. In real-world code, variables often aren’t being set a few lines above where you’re using them. Instead, they’re coming from user input or a config file or command line parsing which happened tens of function calls and thousands of lines away from where you are encountering your traceback. After you step through your program for a few hours you may be able to realize that the relation between your variable and where it is used looks something like this:

# Near the start of your program
name = raw_input("Your name")
if not name.strip():
    name = u"くらとみ"

# [..thousands of lines of code..]

print(u"Hello Mr. {0}!".format(name))

So what’s happening? There’s two ways that our variable could be set. One of those ways (the return from raw_input()) sets it to a byte str. The other way (when we set the default value) sets it to a unicode string. The way we’re using the variable in the print() function means that the value will be converted to a unicode string if it’s a byte string. Remember that we earlier determined that ascii-only byte strings would convert but non-ascii byte strings would throw an error. So that means the code will behave correctly if the default is used or if the user enters “Kuratomi” but it will throw an exception if the user enters “くらとみ” because it has non-ascii characters.
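
The implicit conversions that python2 performs use the ascii codec, so the rule can be checked by doing the same conversions by hand. The sketch below does that; implicit_ok is a helper invented for this illustration, and the explicit .decode("ascii") and .encode("ascii") calls are stand-ins for what python2 does behind the scenes:

```python
# -*- coding: utf-8 -*-
name = u"\u304f\u3089\u3068\u307f"  # the text string u"くらとみ"
name_bytes = name.encode("utf-8")   # the same name as a UTF-8 byte string

def implicit_ok(func):
    """Return True if the conversion succeeds, False if it raises UnicodeError."""
    try:
        func()
    except UnicodeError:
        return False
    return True

# byte str -> unicode: what a unicode format string does to a byte str parameter
print(implicit_ok(lambda: b"Kuratomi".decode("ascii")))  # True
print(implicit_ok(lambda: name_bytes.decode("ascii")))   # False

# unicode -> byte str: what a byte str format string does to a unicode parameter
print(implicit_ok(lambda: u"Kuratomi".encode("ascii")))  # True
print(implicit_ok(lambda: name.encode("ascii")))         # False
```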

This is where explicit conversion comes in. We need to explicitly convert the value to a unicode string so that we do not throw a traceback when we use it later. There’s two sensible locations to do that conversion. The better long term option is to convert where the variable is being set:

name = raw_input("Your name")
name = unicode(name, "utf-8", "replace")
if not name.strip():
    name = u"くらとみ"

Doing it there means that everywhere in your code you know that the variable will contain a unicode string. If you do this to all of your variables you will get to the point where you know that all of your variables are unicode strings unless you are explicitly converting them to byte str (or have special variables that should always be bytes — in which case you should have a naming convention to identify them). Having this sort of default makes it much easier to write code that uses the variable without fearing that it will unexpectedly cause tracebacks.

The other point at which you can convert is at the point that the variable is being used:

if isinstance(name, str):
    name = unicode(name, 'utf-8', 'replace')
print(u"Hello Mr. {0}!".format(name))

The drawbacks to converting the variable here include having to put this code in wherever you are using it (usually more places than the variable could be set) and having to add the isinstance check because you don’t know whether it was set to a unicode or str type at this point. However, it can be useful to use this strategy when you have some critical code deployed and you know you’re getting tracebacks at a specific location but don’t know what unintended consequences might occur from changing the type of the variable everywhere. In this case you might analyze the problem for a bit and decide to hotfix your production machines to convert at the point of use, while in your development tree you convert where the variable is being set. That gives you a bit more time to work your way through all the places where the change reveals that you are mixing string types.
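Both conversion points boil down to the same operation, so it can help to keep it in one small helper. The following is only a sketch: to_text, text_type and binary_type are names made up for this illustration, and the utf-8/replace defaults are an assumption carried over from the examples above.

```python
# -*- coding: utf-8 -*-
import sys

# Pick the text and byte types for the running interpreter.
if sys.version_info[0] >= 3:
    text_type, binary_type = str, bytes
else:
    text_type, binary_type = unicode, str  # noqa: F821 (python2 only)


def to_text(value, encoding="utf-8", errors="replace"):
    """Return value as a text string, decoding it only if it is bytes."""
    if isinstance(value, binary_type):
        return value.decode(encoding, errors)
    return value


# Hypothetical inputs: the same name as bytes and as already-decoded text
from_bytes = to_text(u"くらとみ".encode("utf-8"))
from_text = to_text(u"くらとみ")

# Both end up as the same text string, so later formatting is safe
print(from_bytes == from_text)
print(isinstance(from_bytes, text_type))
```

Because the helper leaves text strings alone, it is safe to call at the point of use even after the variable has also been converted where it is set.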

Dear Lazyweb, how would you nicely bundle python code?

I’ve been looking into bundling the python six library into ansible because it’s getting painful to maintain compatibility with the old versions on some distros. However, the distribution developer in me wanted to make it easy for distro packagers to make sure the system copy was used rather than the bundled copy if needed and also make it easy for other ansible developers to make use of it. It seemed like the way to achieve that was to make an import in our namespace that would transparently decide which version of six was needed and use that. I figured out three ways of doing this but haven’t figured out which is better. So throwing the three ways out there in the hopes that some python gurus can help me understand the pros and cons of each (and perhaps improve on what I have so far).

Boilerplate
To be both transparent to our developers and use system packages if the system had a recent enough six, I created a six package in our namespace. Inside of this package I included the real six library as _six.py. Then I created an __init__.py with code to decide whether to use the system six or the bundled _six.py. So the directory layout is like this:

+ ansible/
  + __init__.py
  + compat/
    + __init__.py
    + six/
      + __init__.py
      + _six.py

__init__.py has two tasks. It has to determine whether we want the system six library or the bundled one. And then it has to make that choice be what other code gets when it does import ansible.compat.six. Here’s the basic boilerplate:

# Does the system have a six library installed?
try:
    import six as _system_six
except ImportError:
    _system_six = None

if _system_six:
    # Various checks that system six library is current enough
    if not hasattr(_system_six.moves, 'shlex_quote'):
        _system_six = None

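if _system_six:
    # Here's where we have to load up the system six library
    pass  # placeholder until we pick a loading strategy below
else:
    # Alternatively, we load up the bundled library
    pass  # placeholder until we pick a loading strategy below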
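
The "is the system copy current enough" step can be factored into a small helper that tries the import and then probes for the attributes we depend on. This is a sketch only: load_if_current and its arguments are hypothetical names, not part of ansible or six, and json stands in for six so the example runs anywhere.

```python
import importlib


def load_if_current(module_name, required_attrs):
    """Return the named module if it imports and has every dotted
    attribute in required_attrs; otherwise return None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None

    for dotted in required_attrs:
        obj = module
        for part in dotted.split("."):
            if not hasattr(obj, part):
                return None
            obj = getattr(obj, part)
    return module


# json is only a stand-in for the library being bundled
current = load_if_current("json", ["dumps", "decoder.JSONDecoder"])
too_old = load_if_current("json", ["feature_that_does_not_exist"])
missing = load_if_current("module_that_is_not_installed_xyz", [])
print(current is not None, too_old, missing)
```

For six, the call would presumably look like load_if_current("six", ["moves.shlex_quote"]), mirroring the hasattr check in the boilerplate.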

Loading using standard import
Now things start to get interesting. We know which version of the six library we want. We just have to make it available to people who are now going to use it. In the past, I’d used the standard import mechanism so that was the first thing I tried here:

if _system_six:
    from six import *
else:
    from ._six import *

As a general way of doing this, it has some caveats. It only pulls in the symbols that the module considers public. If a module has any functions or variables that are supposed to be public and marked with a leading underscore then they won’t be pulled in. Or if a module has an __all__ = [...] that doesn’t contain all of the public symbols then those won’t get pulled in. You can pull those additions in by specifying them explicitly if you have to.
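
Those two caveats are easy to see with a throwaway module built in memory. The module name star_import_demo and its contents are invented for this illustration.

```python
import sys
import types

# Build a throwaway module in memory and register it so it can be imported
demo = types.ModuleType("star_import_demo")
demo.public = 1
demo._private = 2
sys.modules["star_import_demo"] = demo

# Without __all__, import * skips names with a leading underscore
default_names = {}
exec("from star_import_demo import *", default_names)
print("public" in default_names)    # True
print("_private" in default_names)  # False

# With __all__, only the listed names are pulled in
demo.__all__ = ["_private"]
listed_names = {}
exec("from star_import_demo import *", listed_names)
print("public" in listed_names)     # False
print("_private" in listed_names)   # True
```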

For this case, we don’t have any issues with those as six doesn’t use __all__ and none of the public symbols are marked with a leading underscore. However, when I then started porting the ansible code to use ansible.compat.six I encountered an interesting problem:

# Simple things like this work
>>> from ansible.compat.six import moves
>>> moves.urllib.parse.urlsplit('https://toshio.fedorapeople.org/')
SplitResult(scheme='https', netloc='toshio.fedorapeople.org', path='/', query='', fragment='')

# this throws an error:
>>> from ansible.compat.six.moves import urllib
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: No module named moves

Hmm… I’m not quite sure what’s happening but I zero in on the word “module”. Maybe there’s something special about modules such that import * doesn’t let me import their subpackages or submodules through my own package. Time to look for answers on the Internet…
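
One likely piece of the puzzle: the dotted form of import looks each component up through the import system (sys.modules and the finders), not as an attribute of the parent module. The sketch below uses made-up in-memory modules (outer_demo and outer_demo.inner) to show the difference; it is an illustration of that general behavior, not a diagnosis of what six does.

```python
import sys
import types

# Two in-memory modules; inner is only an *attribute* of outer for now
outer = types.ModuleType("outer_demo")
inner = types.ModuleType("outer_demo.inner")
inner.value = 42
outer.inner = inner
sys.modules["outer_demo"] = outer

# Attribute-style access works, just like "from ansible.compat.six import moves"
from outer_demo import inner as via_attribute

# The dotted form fails: outer_demo is not a package and
# "outer_demo.inner" is not known to the import system
try:
    from outer_demo.inner import value
    dotted_import_worked = True
except ImportError:
    dotted_import_worked = False
print(dotted_import_worked)

# Once the dotted name is registered, the same import succeeds
sys.modules["outer_demo.inner"] = inner
from outer_demo.inner import value as value_after
print(value_after)
```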

The Sorta, Kinda, Hush-Hush, Semi-Official Way

Googling for a means to replace a module from itself eventually leads to a strategy that seems to have both some people who like it and some who don’t. It seems to be supported officially but people don’t want to encourage people to use it. It involves a module replacing its own entry in sys.modules. Going back to our example, it looks like this:

import sys
[...]
if _system_six:
    six = _system_six
else:
    from . import _six as six

sys.modules['ansible.compat.six'] = six

When I ran this with a simple test case of a python package with several nested modules, that seemed to clear up the problem. I was able to import submodules of the real module from my fake module just fine. So I was hopeful that everything would be fine when I implemented it for six.
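
Here is a minimal, self-contained version of that kind of test, assuming a throwaway module called shim_demo (a made-up name) that swaps itself out for the standard library's json module:

```python
import json
import os
import sys
import tempfile

# Write a throwaway module that replaces its own entry in sys.modules
tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "shim_demo.py"), "w") as f:
    f.write(
        "import sys\n"
        "import json as _real\n"
        "sys.modules[__name__] = _real\n"
    )
sys.path.insert(0, tmpdir)

# The import statement hands back whatever ended up in sys.modules
import shim_demo
from shim_demo import dumps as shim_dumps

print(shim_demo is json)         # True
print(shim_dumps is json.dumps)  # True
```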

Nope:

>>> from ansible.compat.six.moves import urllib
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: No module named moves

Hmm… same error. So I take a look inside of six.py to see if there’s any clue as to why my simple test case with multiple files and directories worked but six’s single file is giving us headaches. Inside I find that six is doing its own magic with a custom importer to make moves work. I spend a little while trying to figure out if there’s something specifically conflicting between my code and six’s code and then throw my hands up. There’s a lot of stuff that I’ve never used before here… it’ll take me a while to wrap my head around it and there’s no assurance that I’ll be able to make my code work with what six is doing even after I understand it. Is there anything else I could try to just tell my code to run everything that six would normally do when it is imported but do it in my ansible.compat.six namespace?

You tell me: Am I beating my code with the ugly stick?

As a matter of fact, python does provide us with a keyword in python2 and a function in python3 that might do exactly that. So here’s strategy number three:

import os.path
[...]
if _system_six:
    import six
else:
    from . import _six as six
six_py_file = '{0}.py'.format(os.path.splitext(six.__file__)[0])
exec (open(six_py_file, 'r'))

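Yep, python2’s exec will take an open file handle of a python module and execute it in the current namespace. (The python3 exec() function only accepts a string or a code object, so there you would pass it open(six_py_file, 'r').read() instead.) So this seems like it will do what we want. Let’s test it: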

>>> from ansible.compat.six.moves import urllib
>>>
>>> from ansible.compat.six.moves.urllib.parse import urlsplit
>>> urlsplit('https://toshio.fedorapeople.org/')
SplitResult(scheme='https', netloc='toshio.fedorapeople.org', path='/', query='', fragment='')

So dear readers, you tell me: I now have some code that works but it relies on exec. And moreover, it relies on exec to overwrite the current namespace. Is this a good idea or a bad idea? Let’s contemplate a little further: is this an idea that should only be applied sparingly (using sys.modules instead if the module isn’t messing around with a custom importer of its own) or is it a general purpose strategy that should be applied to other libraries that I might bundle as well? Are there caveats to doing things this way? For instance, is it bypassing the standard import caching and so might be slower? Is there a better way to do this that in my ignorance I just don’t know about?
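
One caveat that can be checked directly: importing the library and then exec-ing its source runs the code twice, so the exec'd namespace holds its own copies of every class and function. The sketch below uses a made-up stand-in module (bundled_demo) rather than six, and reads and compiles the source so that it runs on both python2 and python3.

```python
import importlib
import os
import sys
import tempfile

# Throwaway stand-in for the bundled library (hypothetical name and contents)
tmpdir = tempfile.mkdtemp()
source_path = os.path.join(tmpdir, "bundled_demo.py")
with open(source_path, "w") as f:
    f.write("class Marker(object):\n    pass\n\nANSWER = 42\n")
sys.path.insert(0, tmpdir)

# First execution: the normal import, cached in sys.modules
bundled_demo = importlib.import_module("bundled_demo")

# Second execution: run the same source again in a namespace of our own
namespace = {"__name__": "exec_demo_namespace"}
with open(source_path) as f:
    code = compile(f.read(), source_path, "exec")
exec(code, namespace)

# Plain values match, but the classes are two distinct objects
print(namespace["ANSWER"] == bundled_demo.ANSWER)              # True
print(namespace["Marker"] is bundled_demo.Marker)              # False
print(isinstance(namespace["Marker"](), bundled_demo.Marker))  # False
```

Whether that matters in practice depends on whether other code imports the real library directly and then compares or type-checks objects against the exec'd copies.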