greedy algorithm with coroutines

2013.May.01

A number of common problems are optimally solved by greedy algorithms: algorithms where a locally optimal choice at each stage of the calculation leads to a globally optimal solution.

The most common example of this is change counting. When making change from a cash register using American coinage, we can minimise the number of coins returned by first returning as many quarters as possible, then as many dimes, then nickels, and finally pennies. At each step of this process, we only need to know the amount still owed; we never need to revisit the choices made in previous steps.

# we'll use decimal.Decimal for exact precision
from decimal import Decimal
quarter, dime, nickel, penny = (Decimal(x) for x in ('0.25', '0.10', '0.05', '0.01'))

# we imagine a cash register with infinite supplies of each denomination
def change(amount):
	returned_change = []

	# try to return as many quarters as possible:
	while sum(returned_change) + quarter <= amount:
		returned_change.append(quarter)
	
	# next try dimes
	while sum(returned_change) + dime <= amount:
		returned_change.append(dime)

	# nickels
	while sum(returned_change) + nickel <= amount:
		returned_change.append(nickel)

	# pennies
	while sum(returned_change) + penny <= amount:
		returned_change.append(penny)
	
	return returned_change

assert change(Decimal('0.75')) == [quarter, quarter, quarter]
assert change(Decimal('0.35')) == [quarter, dime]
assert change(Decimal('0.43')) == [quarter, dime, nickel, penny, penny, penny]

Let’s try refactoring this to boil it down to its essence.

We’ll now parameterise this on the coinage (so we can calculate change when we go visit Australia).

We’ll also turn it into a generator, since maybe it’s more polite for us to allow the caller to decide what structure to put the results into.

# we'll use decimal.Decimal for exact precision
from decimal import Decimal
quarter, dime, nickel, penny = (Decimal(x) for x in ('0.25', '0.10', '0.05', '0.01'))
coinage = set([quarter, dime, nickel, penny])

def change(amount, coinage=coinage):
	returned_amount = Decimal('0')
	for coin in reversed(sorted(coinage)):
		while returned_amount + coin <= amount:
			yield coin
			returned_amount += coin

assert list(change(Decimal('0.75'))) == [quarter, quarter, quarter]
assert list(change(Decimal('0.35'))) == [quarter, dime]
assert list(change(Decimal('0.43'))) == [quarter, dime, nickel, penny, penny, penny]

Note that the essence of the above is: keep picking coins from our infinite register until we meet some condition (the next coin would take us past the amount we want to return), then move on to the next denomination.

Let’s turn this inside out and try to use itertools primitives to model this simple greedy algorithm:

from itertools import repeat, chain, takewhile
greedy = lambda items, predicate: chain.from_iterable(takewhile(predicate,repeat(x)) for x in reversed(sorted(items)))

We’ll take a set of items and a predicate and give the result of picking items until that predicate signals us to move on.
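To make the one-liner easier to read, here is a small sketch of what each of the three itertools primitives does on its own. The `budget` counter and the `under_budget` predicate are hypothetical names used only for this illustration.

```python
from itertools import chain, repeat, takewhile

# repeat(x) is an endless stream of x; takewhile cuts the stream off at the
# first item for which the predicate returns a false value.
budget = [3]                      # hypothetical counter: allow three items

def under_budget(item):
    if budget[0] > 0:
        budget[0] -= 1
        return True
    return False

three_tens = list(takewhile(under_budget, repeat(10)))

# chain.from_iterable lazily flattens an iterable of iterables.
flattened = list(chain.from_iterable([[25, 25], [10], []]))

print(three_tens)   # [10, 10, 10]
print(flattened)    # [25, 25, 10]
```

Because all three are lazy, the predicate is only ever called on one item at a time, in order, which is what lets it carry state from one call to the next.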

In the example of giving change, we’ll take coins of each denomination, from largest to smallest, until the predicate tells us to stop:

from itertools import repeat, chain, takewhile
def greedy(items, predicate):
	return chain.from_iterable( # each takewhile gives an iterator, so let's flatten them
	         takewhile(predicate,repeat(x)) # take from an infinite reservoir of items (coins)
	                                        #   until we're told to stop
	         for x in reversed(sorted(items))) # do this for each of the items
	                                           #   we're allowed to pick from (coin denominations)
	                                           #   and do it in reverse, sorted order (big coins to small coins)

Now, predicate is easy to write. It just needs to be some stateful callable that keeps track of the amount we’ve built up so far and tells us when to stop.

from decimal import Decimal
quarter, dime, nickel, penny = (Decimal(x) for x in ('0.25', '0.10', '0.05', '0.01'))
coinage = set([quarter, dime, nickel, penny])

def pred(amount, state=None):
	state = state or [] # capture the state in the closure
	def takecoin(item):
		if sum(state) + item <= amount:
			state.append(item)
			return True
		return False
	return takecoin

from itertools import repeat, chain, takewhile
greedy = lambda items, predicate: chain.from_iterable(takewhile(predicate,repeat(x)) for x in reversed(sorted(items)))

assert list(greedy(coinage, pred(Decimal('0.75')))) == [quarter, quarter, quarter]
assert list(greedy(coinage, pred(Decimal('0.35')))) == [quarter, dime]
assert list(greedy(coinage, pred(Decimal('0.43')))) == [quarter, dime, nickel, penny, penny, penny]

We don’t need to use a rich object for the above; we can capture the current state easily enough by just closing around some mutable state.
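For comparison, the "rich object" alternative might look something like the sketch below. `TakeCoin` is a hypothetical name, and `greedy` is written out as a plain function (using `sorted(..., reverse=True)`) so the block runs on its own.

```python
from decimal import Decimal
from itertools import chain, repeat, takewhile

class TakeCoin(object):
    """Hypothetical class-based equivalent of the closure-based predicate."""
    def __init__(self, amount):
        self.amount = amount
        self.total = Decimal('0')   # running total of accepted items

    def __call__(self, item):
        # accept the item only if it does not take us past the target amount
        if self.total + item <= self.amount:
            self.total += item
            return True
        return False

def greedy(items, predicate):
    return chain.from_iterable(takewhile(predicate, repeat(x))
                               for x in sorted(items, reverse=True))

quarter, dime, nickel, penny = (Decimal(x) for x in ('0.25', '0.10', '0.05', '0.01'))
coinage = {quarter, dime, nickel, penny}

result = list(greedy(coinage, TakeCoin(Decimal('0.43'))))
assert result == [quarter, dime, nickel, penny, penny, penny]
```

The behaviour is the same; the class simply spends more lines naming the state that the closure captures implicitly.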

We can also model our stateful predicate using a generator. We’ll retain the same interface (the predicate is just some callable), but instead of capturing state in some mutable cell of our closure (the state list in the above), we’ll capture the state directly in the generator.

This works because a generator is merely a delayed, pausable, and resumable computation. In CPython terms, a generator is merely a function and some state for that function: a code object representing the code being run and a frame object, which contains all the standard information we’d expect in a stack frame. This includes the programme counter (last instruction processed), the set of local variables, &c.
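A rough way to see this for yourself is to poke at a paused generator. The sketch below assumes CPython (the `gi_frame` and `gi_code` attributes are implementation details) and uses a hypothetical `running_total` generator. It also shows that a freshly created generator has not yet reached a `yield`, so it cannot receive a value until it has been advanced once.

```python
import inspect

def running_total():
    total = 0
    while True:
        item = yield total      # pause here; resume when a value is sent in
        total += item

gen = running_total()
state_before = inspect.getgeneratorstate(gen)   # created, but not yet started

# A just-started generator is not paused at a yield, so there is nothing to
# receive a sent value yet: this raises TypeError.
try:
    gen.send(5)
    send_error = None
except TypeError as exc:
    send_error = exc

next(gen)                       # run up to the first yield ("priming")
gen.send(5)
gen.send(7)

state_after = inspect.getgeneratorstate(gen)
saved_total = gen.gi_frame.f_locals['total']    # local variable kept in the frame
code_name = gen.gi_code.co_name                 # the code object being run

print(state_before, state_after, saved_total, code_name)
```

The running total lives in the generator's frame between calls, which is exactly the state we previously kept in a list inside the closure.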

from decimal import Decimal
quarter, dime, nickel, penny = (Decimal(x) for x in ('0.25', '0.10', '0.05', '0.01'))
coinage = set([quarter, dime, nickel, penny])

# this is a syntactical annoyance with Python coroutines:
#   we need to pump/prime the generator
# this is the source of the "can't send non-None value to a just-started generator" error
# I'll explain this in depth in a future blog post
def primed(gen):
	def decorator(*args, **kwargs):
		instance = gen(*args, **kwargs)
		next(instance) # pump/prime
		return instance
	return decorator

@primed
def accept(amount, state=0):
	item = yield None # priming stops here; the first item sent in arrives at this point
	while True:
		if state+item <= amount:
			state += item
			item = yield True
		else:
			item = yield False

from itertools import repeat, chain, takewhile
greedy = lambda items, predicate: chain.from_iterable(takewhile(predicate,repeat(x)) for x in reversed(sorted(items)))

assert list(greedy(coinage, accept(Decimal('0.75')).send)) == [quarter, quarter, quarter]
assert list(greedy(coinage, accept(Decimal('0.35')).send)) == [quarter, dime]
assert list(greedy(coinage, accept(Decimal('0.43')).send)) == [quarter, dime, nickel, penny, penny, penny]

In this case, this isn’t necessarily any better than the mutable closure from above. There are, however, cases where this may be a faster, cleaner, or better solution. (For example, it’s much clumsier to keep a value type, such as a running total, in a closure, since you have to construct your own mutable cell to hold it; in a generator it is just a local variable.)
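To illustrate the point about cells, here is a small sketch with two hypothetical predicates, `broken_pred` and `cell_pred`. The first tries to keep a plain number in the closure and fails; the second wraps the number in a one-element list, which is the hand-made cell.

```python
def broken_pred(amount):
    total = 0
    def takecoin(item):
        # the assignment below makes `total` local to takecoin, so this read
        # fails with UnboundLocalError before anything has been assigned
        if total + item <= amount:
            total += item
            return True
        return False
    return takecoin

def cell_pred(amount):
    total = [0]                 # one-element list used as a hand-made cell
    def takecoin(item):
        if total[0] + item <= amount:
            total[0] += item    # mutates the list; never rebinds the name
            return True
        return False
    return takecoin

try:
    broken_pred(10)(5)
    error = None
except UnboundLocalError as exc:
    error = exc

take = cell_pred(10)
answers = [take(5), take(5), take(5)]
assert answers == [True, True, False]
```

Python 3 adds the `nonlocal` statement, which removes the need for the list, but the generator version sidesteps the question entirely.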

One nice thing about either one of these solutions is that we can apply it to a variety of problems.

For example, both change counting and conversion from Arabic to Roman numerals can be solved with the same modelling.

# vim: set fileencoding=utf-8 :
from __future__ import division

def primed(gen):
	def decorator(*args, **kwargs):
		instance = gen(*args, **kwargs)
		next(instance) # pump/prime
		return instance.send
	return decorator

@primed
def pred(amount, state=0):
	item = yield None
	while True:
		if state+item <= amount:
			state += item
			item = yield True
		else:
			item = yield False

from itertools import repeat, chain, takewhile
greedy = lambda items, predicate: chain.from_iterable(takewhile(predicate,repeat(x)) for x in reversed(sorted(items)))

# arabic to roman numerals problem: symbol mapping
mapping = {  1:  'i',   4: 'iv',    5:  'v',   9: 'ix',  10: 'x',
            40: 'xl',  50:  'l',   90: 'xc', 100:  'c', 400: 'cd',
           500:  'd', 900: 'cm', 1000:  'm',}
class numerals(list):
	def __format__(self, fmt):
		return ''.join(mapping[x].upper() for x in self)

# change counting problem: coin denominations
denominations = {1,5,10,25,100,500,1000,2000}
class purse(list):
	def __format__(self, fmt):
		return ' + '.join('{:d}×{}'.format(
		               sum(1 for _ in cs),
		               ('{:d}¢' if c < 100 else '{:.0f}$').format(c if c < 100 else c/100))
		             for c,cs in groupby(self))

if __name__ == '__main__':
	from itertools import groupby
	from random import randint

	# range and print(...) with a single argument behave the same on Python 2 and 3
	for _ in range(4):
		arabic = randint(1900,2200)
		roman = greedy(mapping,pred(arabic))
		print('The year {} is written {}'.format(arabic, numerals(roman)))

	for _ in range(4):
		amount = randint(0,1000)
		coins = greedy(denominations,pred(amount))
		print('Your change for {:.2f}$ = {}'.format(amount/100, purse(coins)))
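Since the script picks its inputs at random, a fixed input is easier to check by hand. The following is a minimal, self-contained sketch of the same modelling using the standard Roman symbols (in particular XL for 40 and L for 50). `symbols`, `to_roman`, and `make_change` are hypothetical names introduced here for illustration.

```python
from itertools import chain, repeat, takewhile

def primed(gen):
    def wrapper(*args, **kwargs):
        instance = gen(*args, **kwargs)
        next(instance)              # advance to the first yield
        return instance.send        # hand back the bound send as the predicate
    return wrapper

@primed
def pred(amount, state=0):
    item = yield None
    while True:
        if state + item <= amount:
            state += item
            item = yield True
        else:
            item = yield False

def greedy(items, predicate):
    return chain.from_iterable(takewhile(predicate, repeat(x))
                               for x in sorted(items, reverse=True))

symbols = {1: 'I', 4: 'IV', 5: 'V', 9: 'IX', 10: 'X', 40: 'XL', 50: 'L',
           90: 'XC', 100: 'C', 400: 'CD', 500: 'D', 900: 'CM', 1000: 'M'}

def to_roman(number):
    # greedily pick the largest values that still fit, then map them to symbols
    return ''.join(symbols[value] for value in greedy(symbols, pred(number)))

def make_change(cents, denominations=(1, 5, 10, 25)):
    return list(greedy(denominations, pred(cents)))

print(to_roman(1994))     # MCMXCIV
print(to_roman(2013))     # MMXIII
print(make_change(43))    # [25, 10, 5, 1, 1, 1]
```

Both problems go through the same `greedy` and the same `pred`; only the set of items and what we do with the picked values differ.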