I don't quite remember how I arrived at this, but it occurred to me
last week that probability distributions form a monad. This is the
first time I've invented a new monad that I hadn't seen before; then I
implemented it and it behaved pretty much the way I thought it would.
So I feel like I've finally arrived, monadwise.
Suppose a monad value represents all the possible outcomes of an
event, each with a probability of occurrence. For concreteness, let's
suppose all our probability distributions are discrete. Then we might
have:
data ProbDist p a = ProbDist [(a,p)] deriving (Eq, Show)
unpd (ProbDist ps) = ps
Each a is an outcome, and each p is the probability of that outcome
occurring. For example, biased and unbiased coins:
unbiasedCoin = ProbDist [ ("heads", 0.5),
                          ("tails", 0.5) ]
biasedCoin   = ProbDist [ ("heads", 0.6),
                          ("tails", 0.4) ]
Or a couple of simple functions for making dice:
import Data.Ratio
d sides = ProbDist [(i, 1 % sides) | i <- [1 .. sides]]
die = d 6
d n is an n-sided die.
The Functor instance is straightforward:
instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist $ map (\(a,p) -> (f a, p)) pas
The Monad instance requires return and >>=. The return function
merely takes an event and turns it into a distribution where that
event occurs with probability 1. I find join easier to think about
than >>=.
The join function takes a nested distribution, where each
outcome of the outer distribution specifies an inner distribution for
the actual events, and collapses it into a regular, overall
distribution. For example, suppose you put a biased coin and an
unbiased coin in a bag, then pull one out and flip it:
bag :: ProbDist Double (ProbDist Double String)
bag = ProbDist [ (biasedCoin, 0.5),
                 (unbiasedCoin, 0.5) ]
The join operator collapses this into a single ProbDist Double String:
ProbDist [("heads",0.3),
("tails",0.2),
("heads",0.25),
("tails",0.25)]
It would be nice if join could combine the duplicate heads into a
single ("heads", 0.55) entry. But that would force an Eq a constraint
on the event type, which isn't allowed, because (>>=) must work for
all data types, not just for instances of Eq. This is a problem with
Haskell, not with the monad itself. It's the same problem that
prevents one from making a good set monad in Haskell, even though
categorially sets are a perfectly good monad. (The return function
constructs singletons, and the join function is simply set union.)
Maybe in the next language.
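The join described above can be written down directly. What follows is only a minimal sketch: the names joinPD and bindPD are made up for illustration and do not appear in the original code, and the coins and bag are repeated so the block stands alone.

```haskell
-- Same representation as in the article.
data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

-- joinPD is a hypothetical name. Each inner outcome's probability is
-- multiplied by the probability of the inner distribution it came from.
joinPD :: Num p => ProbDist p (ProbDist p a) -> ProbDist p a
joinPD (ProbDist outer) =
  ProbDist [ (a, p * q) | (ProbDist inner, p) <- outer, (a, q) <- inner ]

-- bindPD (also hypothetical) recovers >>= from join: map f over the
-- outcomes to get a nested distribution, then collapse it.
bindPD :: Num p => ProbDist p a -> (a -> ProbDist p b) -> ProbDist p b
bindPD (ProbDist pas) f = joinPD (ProbDist [ (f a, p) | (a, p) <- pas ])

unbiasedCoin, biasedCoin :: ProbDist Double String
unbiasedCoin = ProbDist [ ("heads", 0.5), ("tails", 0.5) ]
biasedCoin   = ProbDist [ ("heads", 0.6), ("tails", 0.4) ]

bag :: ProbDist Double (ProbDist Double String)
bag = ProbDist [ (biasedCoin, 0.5), (unbiasedCoin, 0.5) ]

-- The collapsed distribution for "pull a coin from the bag and flip it".
flipFromBag :: ProbDist Double String
flipFromBag = joinPD bag
```

Note that joinPD needs only Num p and puts no constraint on the outcome type, which is why the duplicate entries survive.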
Perhaps someone else will find the >>= operator easier to
understand than join? I don't know. Anyway, it's simple
enough to derive once you understand join; here's the
code:
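-- GHC 7.10 and later require an Applicative instance alongside Monad
instance (Num p) => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  ProbDist pfs <*> ProbDist pas =
    ProbDist [(f a, p*q) | (f, p) <- pfs, (a, q) <- pas]
instance (Num p) => Monad (ProbDist p) where
  return = pure
  (ProbDist pas) >>= f = ProbDist $ do
    (a, p) <- pas
    let (ProbDist pbs) = f a
    (b, q) <- pbs
    return (b, p*q)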
So now we can do some straightforward experiments:
liftM2 (+) (d 6) (d 6)
ProbDist [(2,1 % 36),(3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),
          (3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),
          (4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),
          (5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),
          (6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),
          (7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),(12,1 % 36)]
This is nasty-looking; we really need to merge the multiple listings
of the same event. Here is a function to do that:
agglomerate :: (Num p, Eq b) => (a -> b) -> ProbDist p a -> ProbDist p b
agglomerate f pd = ProbDist $ foldr insert [] (unpd (fmap f pd)) where
  insert (k, p) [] = [(k, p)]
  insert (k, p) ((k', p'):kps) | k == k'   = (k, p+p'):kps
                               | otherwise = (k', p'):(insert (k,p) kps)
agg :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
agg = agglomerate id
Then agg $ liftM2 (+) (d 6) (d 6) produces:
ProbDist [(12,1 % 36),(11,1 % 18),(10,1 % 12),(9,1 % 9),
(8,5 % 36),(7,1 % 6),(6,5 % 36),(5,1 % 9),
(4,1 % 12),(3,1 % 18),(2,1 % 36)]
Hey, that's correct.
There must be a shorter way to write insert. It really
bothers me, because it looks like it should be possible to do it as a
fold. But I couldn't make it look any better.
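One shorter route, at the price of a stronger constraint, is to let Data.Map do the merging. This is only a sketch under assumptions: it requires Ord on the outcome type where agglomerate requires only Eq, it depends on the containers package, and the names agglomerateOrd and aggOrd are invented here.

```haskell
import qualified Data.Map.Strict as Map
import Data.Ratio ((%))

-- Same representation as in the article.
data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

-- Hypothetical Ord-based variant of agglomerate. fromListWith (+) adds
-- up the weights of entries whose mapped outcomes are equal.
agglomerateOrd :: (Num p, Ord b) => (a -> b) -> ProbDist p a -> ProbDist p b
agglomerateOrd f (ProbDist pas) =
  ProbDist (Map.toList (Map.fromListWith (+) [ (f a, p) | (a, p) <- pas ]))

aggOrd :: (Num p, Ord a) => ProbDist p a -> ProbDist p a
aggOrd = agglomerateOrd id

-- The coin-from-a-bag distribution, with exact rational weights.
example :: ProbDist Rational String
example = aggOrd (ProbDist [ ("heads", 3 % 10), ("tails", 1 % 5)
                           , ("heads", 1 % 4),  ("tails", 1 % 4) ])
```

The outcomes come back in ascending key order rather than in the order the foldr version produces, so the two are interchangeable only when the order of the list does not matter.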
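You are not limited to calculating probabilities. The monad actually
will count things, if each outcome is given a weight of 1 instead of a
probability. For example, let us throw three dice and count how
many ways there are to throw various numbers of sixes. (The counts
below assume that die is ProbDist [(i, 1) | i <- [1 .. 6]]; with
die = d 6 the same expression yields probabilities directly.)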
eq6 n = if n == 6 then 1 else 0
agg $ liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die
ProbDist [(3,1),(2,15),(1,75),(0,125)]
There is one way to throw three sixes, 15 ways to throw two sixes, 75
ways to throw one six, and 125 ways to throw no sixes. So ProbDist is
a misnomer.
It's easy to convert counts to probabilities:
probMap :: (p -> q) -> ProbDist p a -> ProbDist q a
probMap f (ProbDist pds) = ProbDist $ (map (\(a,p) -> (a, f p))) pds
normalize :: (Fractional p) => ProbDist p a -> ProbDist p a
normalize pd@(ProbDist pas) = probMap (/ total) pd where
  total = sum . (map snd) $ pas
normalize $ agg $ probMap toRational $
  liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die
ProbDist [(3,1 % 216),(2,5 % 72),(1,25 % 72),(0,125 % 216)]
I think this is the first time I've gotten to write die die die in a
computer program.
The do notation is very nice. Here we calculate the
distribution where we roll four dice and discard the smallest:
stat = do
  a <- d 6
  b <- d 6
  c <- d 6
  d <- d 6
  return (a+b+c+d - minimum [a,b,c,d])
probMap fromRational $ agg stat
ProbDist [(18,1.6203703703703703e-2),
(17,4.1666666666666664e-2), (16,7.253086419753087e-2),
(15,0.10108024691358025), (14,0.12345679012345678),
(13,0.13271604938271606), (12,0.12885802469135801),
(11,0.11419753086419752), (10,9.41358024691358e-2),
(9,7.021604938271606e-2), (8,4.7839506172839504e-2),
(7,2.9320987654320986e-2), (6,1.6203703703703703e-2),
(5,7.716049382716049e-3), (4,3.0864197530864196e-3),
(3,7.716049382716049e-4)]
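The same calculation can be written for any number of dice with replicateM. This is a sketch rather than the code that produced the numbers above: dropLowest is a made-up name, the types are pinned to Rational and Integer for simplicity, and the instances are repeated so that the block compiles on its own.

```haskell
import Control.Monad (replicateM)
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

unpd :: ProbDist p a -> [(a, p)]
unpd (ProbDist ps) = ps

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [ (f a, p) | (a, p) <- pas ]

instance Num p => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  ProbDist fs <*> ProbDist xs =
    ProbDist [ (f a, p * q) | (f, p) <- fs, (a, q) <- xs ]

instance Num p => Monad (ProbDist p) where
  ProbDist pas >>= f =
    ProbDist [ (b, p * q) | (a, p) <- pas, (b, q) <- unpd (f a) ]

d :: Integer -> ProbDist Rational Integer
d sides = ProbDist [ (i, 1 % sides) | i <- [1 .. sides] ]

-- Merge duplicate outcomes, as agg does in the article.
agg :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
agg (ProbDist pas) = ProbDist (foldr insert [] pas)
  where
    insert (k, p) [] = [(k, p)]
    insert (k, p) ((k', p') : kps)
      | k == k'   = (k, p + p') : kps
      | otherwise = (k', p') : insert (k, p) kps

-- Hypothetical helper: roll n dice and discard the smallest.
-- Assumes n >= 1, since minimum fails on an empty list.
dropLowest :: Int -> Integer -> ProbDist Rational Integer
dropLowest n sides = agg $ do
  rolls <- replicateM n (d sides)
  return (sum rolls - minimum rolls)
```

Under these assumptions, dropLowest 4 6 plays the role of agg stat.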
One thing I was hoping to get didn't work out. I had this idea that
I'd be able to calculate the outcome of a game of craps like this:
dice = liftM2 (+) (d 6) (d 6)
point n = do
  roll <- dice
  case roll of
    7 -> return "lose"
    _ | roll == n -> return "win"
      | otherwise -> point n
craps = do
  roll <- dice
  case roll of
    2  -> return "lose"
    3  -> return "lose"
    4  -> point 4
    5  -> point 5
    6  -> point 6
    7  -> return "win"
    8  -> point 8
    9  -> point 9
    10 -> point 10
    11 -> return "win"
    12 -> return "lose"
This doesn't work at all; point is an infinite loop because the first
value of dice, namely 2, causes a recursive call.
I might be able to do something about this, but I'll have to think
about it more.
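One way to get a terminating calculation, offered purely as a sketch, is to cut the recursion off after a fixed number of rolls and report whatever is left over as "undecided". The names pointWithin and crapsWithin, the roll limit k, and the "undecided" outcome are all inventions for this example; merging with agg at every level keeps the intermediate lists small.

```haskell
import Control.Monad (liftM2)
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

unpd :: ProbDist p a -> [(a, p)]
unpd (ProbDist ps) = ps

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [ (f a, p) | (a, p) <- pas ]

instance Num p => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  ProbDist fs <*> ProbDist xs =
    ProbDist [ (f a, p * q) | (f, p) <- fs, (a, q) <- xs ]

instance Num p => Monad (ProbDist p) where
  ProbDist pas >>= f =
    ProbDist [ (b, p * q) | (a, p) <- pas, (b, q) <- unpd (f a) ]

d :: Integer -> ProbDist Rational Integer
d sides = ProbDist [ (i, 1 % sides) | i <- [1 .. sides] ]

-- Merge duplicate outcomes, as agg does in the article.
agg :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
agg (ProbDist pas) = ProbDist (foldr insert [] pas)
  where
    insert (k, p) [] = [(k, p)]
    insert (k, p) ((k', p') : kps)
      | k == k'   = (k, p + p') : kps
      | otherwise = (k', p') : insert (k, p) kps

dice :: ProbDist Rational Integer
dice = agg (liftM2 (+) (d 6) (d 6))

-- Try to make the point n within at most k further rolls.
pointWithin :: Int -> Integer -> ProbDist Rational String
pointWithin k n
  | k <= 0    = return "undecided"
  | otherwise = agg $ do
      roll <- dice
      case roll of
        7 -> return "lose"
        _ | roll == n -> return "win"
          | otherwise -> pointWithin (k - 1) n

-- A whole game, with at most k rolls after the come-out roll.
crapsWithin :: Int -> ProbDist Rational String
crapsWithin k = agg $ do
  roll <- dice
  case roll of
    _ | roll `elem` [7, 11]    -> return "win"
      | roll `elem` [2, 3, 12] -> return "lose"
      | otherwise              -> pointWithin k roll
```

The "undecided" weight shrinks as k grows, so the "win" and "lose" entries approach the true probabilities from below. This sidesteps the infinite loop rather than solving the underlying problem of summing the infinite series.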
It also occurred to me that the use of * in the definition of
>>= / join could be generalized. A couple of years
back I mentioned a
paper of Green, Karvounarakis, and Tannen that discusses
"provenance semirings". The idea is that each item in a database is
annotated with some "provenance" information about why it is there,
and you want to calculate the provenance for items in tables that are
computed from table joins. My earlier
explanation is here.
One special case of provenance information is that the provenances are
probabilities that the database information is correct, and then the
probabilities are calculated correctly for the joins, by
multiplication and addition of probabilities. But in the general case
the provenances are opaque symbols, and the multiplication and
addition construct regular expressions over these symbols. One could
generalize ProbDist similarly, and the ProbDist
monad (even more of a misnomer this time) would calculate the
provenance automatically. It occurs to me now that there's probably a
natural way to view a database table join as a sort of Kleisli
composition, but this article has gone on too long already.
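To make the generalization concrete, here is a rough sketch in which the weights come from an arbitrary semiring. Everything in it is hypothetical: the Semiring class, the Annotated type, the merge function, and the Prov type of symbolic provenance expressions are invented for illustration and are not taken from the paper or from any library.

```haskell
-- Hypothetical class: one and <.> replace 1 and *, zero and <+> replace 0 and +.
class Semiring s where
  zero, one    :: s
  (<+>), (<.>) :: s -> s -> s

-- ProbDist with its weights drawn from any semiring.
newtype Annotated s a = Annotated { annotations :: [(a, s)] }
  deriving (Eq, Show)

instance Functor (Annotated s) where
  fmap f (Annotated xs) = Annotated [ (f a, s) | (a, s) <- xs ]

instance Semiring s => Applicative (Annotated s) where
  pure a = Annotated [(a, one)]
  Annotated fs <*> Annotated xs =
    Annotated [ (f a, s <.> t) | (f, s) <- fs, (a, t) <- xs ]

instance Semiring s => Monad (Annotated s) where
  Annotated xs >>= f =
    Annotated [ (b, s <.> t) | (a, s) <- xs, (b, t) <- annotations (f a) ]

-- Merging duplicate items is where the semiring addition comes in.
merge :: (Semiring s, Eq a) => Annotated s a -> Annotated s a
merge (Annotated xs) = Annotated (foldr insert [] xs)
  where
    insert (k, s) [] = [(k, s)]
    insert (k, s) ((k', s') : rest)
      | k == k'   = (k, s <+> s') : rest
      | otherwise = (k', s') : insert (k, s) rest

-- Ordinary probabilities are one instance.
instance Semiring Double where
  zero  = 0
  one   = 1
  (<+>) = (+)
  (<.>) = (*)

-- Opaque provenance symbols, combined purely symbolically. No
-- simplification is done, so the semiring laws hold only up to
-- rewriting of these expressions.
infixl 6 :+:
infixl 7 :*:
data Prov = Zero | One | Source String | Prov :+: Prov | Prov :*: Prov
  deriving (Eq, Show)

instance Semiring Prov where
  zero  = Zero
  one   = One
  (<+>) = (:+:)
  (<.>) = (:*:)
```

With the Double instance this behaves like ProbDist Double; with Prov, each result carries an expression recording which sources were multiplied and added to produce it.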
Happy new year, everyone.
[ Addendum 20100103: unsurprisingly, this is not a new idea. Several
readers wrote in with references to
previous discussion of this monad, and related monads. It turns
out that the idea goes back at least to 1981. ]
My thanks to Graham Hunter for his donation.