This is the first in a series of posts about program derivation. In particular, I am attempting to derive a matrix multiplication algorithm that runs efficiently on parallel architectures such as GPUs.
As I mentioned in an earlier post, I’ve been contributing to the Accelerate project. The Accelerate EDSL defines various parallel primitives such as map, fold, and scan (and many more).
The scan primitive (also known as all-prefix-sums) is quite famous because it is useful in a wide range of parallel algorithms. At first glance, one could be forgiven for thinking it is not amenable to parallelisation. However, Guy Blelloch popularised a well-known work-efficient algorithm for scan that performs \(O(n)\) work.
The algorithm is undeniably clever. Looking at it, it is not at all obvious how one might have gone about developing it oneself. A recent series of posts by Conal Elliott set out to redress this situation by attempting to derive the algorithm from a high level specification. His success has inspired me to follow a similar process to derive a work efficient matrix multiplication algorithm.
The process I am following is roughly as follows:
- Generalise the concept of matrix multiplication to data structures other than lists or arrays.
- Develop a generic implementation that relies, as far as possible, on reusable algebraic machinery in type classes such as Functor, Applicative, Foldable and Traversable.
- Use this generic implementation as a specification to derive an efficient algorithm. To call it a hunch that the underlying data structure is going to be tree-like is an understatement.
This post is a preamble. It develops a generic dot product implementation that will serve as a specification for the derivation of an efficient algorithm in a later post.
Background
In order to understand this post I highly recommend that you read Conor McBride and Ross Paterson’s paper: Applicative programming with effects. A basic grasp of linear algebra would also be helpful.
What is a dot product?
In mathematics the dot product is usually defined on vectors. Given two vectors of length \(n\), \((a_1, \dots, a_n)\) and \((b_1, \dots, b_n)\), the dot product is defined as:
\(a_1 b_1 + \dots + a_n b_n\)
or without the use of the pernicious “\(\dots\)”:
\(\sum_{i=1}^{n}{a_i b_i}\)
The implementation for lists is fairly straightforward.
dot :: Num a => [a] -> [a] -> a
dot xs ys = foldl (+) 0 (zipWith (*) xs ys)
This definition works fine on two lists of different lengths. That is because zipWith stops as soon as either list runs out:
zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
zipWith f [] _ = []
zipWith f _ [] = []
zipWith f (x:xs) (y:ys) = f x y : zipWith f xs ys
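A quick check makes the truncation concrete. This sketch reuses the list dot from above. The names equalLengths and unequalLengths are only illustrative:

```haskell
-- The list-based dot product from above (Prelude's zipWith behaves like
-- the definition just shown).
dot :: Num a => [a] -> [a] -> a
dot xs ys = foldl (+) 0 (zipWith (*) xs ys)

-- Equal lengths: 1*4 + 2*5 + 3*6 = 32.
equalLengths :: Integer
equalLengths = dot [1, 2, 3] [4, 5, 6]

-- Unequal lengths: zipWith stops at the shorter list, so the trailing 3
-- is silently ignored: 1*4 + 2*5 = 14.
unequalLengths :: Integer
unequalLengths = dot [1, 2, 3] [4, 5]
```

The first evaluates to 32 and the second to 14. The extra element of the longer list is dropped silently, with no error reported.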
This is fine for lists but will become problematic later when we look at other data structures.
There is no reason that this definition shouldn’t be extended to other data structures such as \(n\)-dimensional arrays or even trees. Let’s look at how we might define dot products on trees.
Dot product on Trees
We define trees as follows (however it does not really matter whether only the leaves contain numbers or whether branch nodes can too):
data Tree a = Leaf a | Branch (Tree a) (Tree a)
For the sake of succinctness, I will represent trees using nested pairs denoted with curly braces. For example, Branch (Leaf 1) (Leaf 2) becomes {1,2}, and Branch (Leaf 1) (Branch (Leaf 2) (Leaf 3)) becomes {1,{2,3}}.
What should be the dot product of {1,{2,3}} and {4,{5,6}}? A reasonable answer would be 1 * 4 + 2 * 5 + 3 * 6 == 32. For each leaf in the first tree, find the corresponding leaf in the second tree and multiply the two together. Then sum all the results.
This definition relies on the two trees having the same shape. To see why, let’s try to define a function in the style of zipWith for trees. Unfortunately, this is problematic.
zipWithT f (Leaf a) (Leaf b) = Leaf (f a b)
zipWithT f (Branch s t) (Branch s' t') = Branch (zipWithT f s s')
                                                (zipWithT f t t')
zipWithT f (Leaf a) (Branch s' t') = {- ? -} undefined
zipWithT f (Branch s t) (Leaf b) = {- ? -} undefined
There’s a problem with the last two cases. I won’t go so far as to say that there is no definition we could provide, but there are clearly several choices that could be made. In each case we must pick an arbitrary element from the Branch argument, then apply the function f to it and the value in the Leaf argument.
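To see how arbitrary such a choice is, here is a sketch that fills in the two missing cases in one possible way. It broadcasts a Leaf’s value across every leaf of the opposing Branch. This rule and the helpers mapT and sumT are my own hypothetical choices, not part of the post’s definitions:

```haskell
-- Unshaped binary trees, as defined above.
data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving Show

mapT :: (a -> b) -> Tree a -> Tree b
mapT f (Leaf a)     = Leaf (f a)
mapT f (Branch s t) = Branch (mapT f s) (mapT f t)

-- Hypothetical rule for the two missing cases: a Leaf facing a Branch
-- has its value broadcast to every leaf of that Branch.
zipWithT :: (a -> b -> c) -> Tree a -> Tree b -> Tree c
zipWithT f (Leaf a)     (Leaf b)       = Leaf (f a b)
zipWithT f (Branch s t) (Branch s' t') = Branch (zipWithT f s s') (zipWithT f t t')
zipWithT f (Leaf a)     t              = mapT (f a) t
zipWithT f s            (Leaf b)       = mapT (\a -> f a b) s

-- Sum of all leaves.
sumT :: Num a => Tree a -> a
sumT (Leaf a)     = a
sumT (Branch s t) = sumT s + sumT t

dotT :: Num a => Tree a -> Tree a -> a
dotT s t = sumT (zipWithT (*) s t)
```

With this rule, the dot product of {1,{2,3}} and {4,{5,6}} is still 32. However, {1,{2,3}} against {4,5} gives 1 * 4 + 2 * 5 + 3 * 5 == 29. Picking the leftmost leaf of the Branch instead would give 1 * 4 + 2 * 5 == 14. Neither answer is obviously the right one.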
Even if there is a definition that makes reasonable sense, can we say whether it’s possible to provide a zipWith-like definition for an arbitrary data structure?
An alternative is to modify our data structures so that their types encode their shape. We can then define dot so that it only accepts two arguments of exactly the same shape.
Data structures with shapes
I’ll illustrate this approach with vectors first, before moving on to trees. Vectors are just lists with their length encoded in their type.
Vectors
First, we add some essentials to the top of our module.
{-# LANGUAGE GADTs, EmptyDataDecls, FlexibleInstances, DeriveFunctor, DeriveFoldable #-}
{-# LANGUAGE ScopedTypeVariables, FlexibleContexts, UndecidableInstances #-}
Now we define two new data types, Z and S, representing Peano numbers. Both data types are empty since we will never use their values.
data Z
data S n
Now vectors.
infixr 5 `Cons`
data Vec n a where
  Nil  :: Vec Z a
  Cons :: a -> Vec n a -> Vec (S n) a
If you haven’t seen these data types before, note that you can define total (rather than partial) versions of head and tail. Trying to take the head of an empty vector simply doesn’t type check.
headVec :: Vec (S n) a -> a
headVec (Cons x _) = x
tailVec :: Vec (S n) a -> Vec n a
tailVec (Cons _ xs) = xs
We can now define zipWithV.
zipWithV :: (a -> b -> c) -> Vec n a -> Vec n b -> Vec n c
zipWithV f Nil Nil = Nil
zipWithV f (Cons x xs) (Cons y ys) = f x y `Cons` zipWithV f xs ys
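Unfortunately, at the time of writing, GHC’s pattern-match checker does not take GADT type refinement into account. If your warnings are turned up high enough, GHC will warn that two patterns are missing from the definition above, even though neither can ever match. (GHC 8.0 and later understand GADTs here and issue no such warning.) The type checker itself does know better. An equation such as the one below can never match, and GHC reports it as inaccessible code (an error or a warning, depending on the version).
-- This equation can never match: Cons forces the length index to be
-- S n, while Nil forces it to be Z.
zipWithV f (Cons x xs) Nil = {- something -} undefined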
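Putting the vector pieces together gives a small, compilable sketch of a dot product on vectors. The fold helper foldrV and the example vectors v123 and v456 are hypothetical additions, not definitions from the post:

```haskell
{-# LANGUAGE GADTs, EmptyDataDecls #-}

data Z
data S n

infixr 5 `Cons`
data Vec n a where
  Nil  :: Vec Z a
  Cons :: a -> Vec n a -> Vec (S n) a

zipWithV :: (a -> b -> c) -> Vec n a -> Vec n b -> Vec n c
zipWithV _ Nil         Nil         = Nil
zipWithV f (Cons x xs) (Cons y ys) = f x y `Cons` zipWithV f xs ys

-- Hypothetical helper: a right fold over a vector.
foldrV :: (a -> b -> b) -> b -> Vec n a -> b
foldrV _ z Nil         = z
foldrV f z (Cons x xs) = f x (foldrV f z xs)

-- Both arguments must have the same length index n.
dotV :: Num a => Vec n a -> Vec n a -> a
dotV xs ys = foldrV (+) 0 (zipWithV (*) xs ys)

v123, v456 :: Vec (S (S (S Z))) Integer
v123 = 1 `Cons` 2 `Cons` 3 `Cons` Nil
v456 = 4 `Cons` 5 `Cons` 6 `Cons` Nil
```

Here dotV v123 v456 evaluates to 32. Applying dotV to vectors of lengths 3 and 2 is a type error rather than a silent truncation.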
Trees
The length of a tree is not a meaningful enough description of its shape. Instead we represent its shape as nested tuples of the unit type (()).
data Tree sh a where
  Leaf   :: a -> Tree () a
  Branch :: Tree m a -> Tree n a -> Tree (m,n) a
For example:
{1,{2,3}} :: Tree ((),((),())) Integer
The new definition of zipWithT differs only in its type.
zipWithT :: (a -> b -> c) -> Tree sh a -> Tree sh b -> Tree sh c
zipWithT f (Leaf a) (Leaf b) = Leaf (f a b)
zipWithT f (Branch s t) (Branch s' t') = Branch (zipWithT f s s')
                                                (zipWithT f t t')
Now we can finish off the definitions:
foldlT :: (a -> b -> a) -> a -> Tree sh b -> a
foldlT f z (Leaf a) = f z a
foldlT f z (Branch s t) = foldlT f (foldlT f z s) t
dotT :: Num a => Tree sh a -> Tree sh a -> a
dotT t1 t2 = foldlT (+) 0 (zipWithT (*) t1 t2)
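Assembled into one module, the shape-encoded tree definitions compile as follows. The example trees are {1,{2,3}} and {4,{5,6}} from earlier, written out with Leaf and Branch. The names t123 and t456 are only illustrative:

```haskell
{-# LANGUAGE GADTs #-}

data Tree sh a where
  Leaf   :: a -> Tree () a
  Branch :: Tree m a -> Tree n a -> Tree (m, n) a

-- Only same-shape cases are needed: mixed Leaf/Branch cases cannot type check.
zipWithT :: (a -> b -> c) -> Tree sh a -> Tree sh b -> Tree sh c
zipWithT f (Leaf a)     (Leaf b)       = Leaf (f a b)
zipWithT f (Branch s t) (Branch s' t') = Branch (zipWithT f s s') (zipWithT f t t')

foldlT :: (a -> b -> a) -> a -> Tree sh b -> a
foldlT f z (Leaf a)     = f z a
foldlT f z (Branch s t) = foldlT f (foldlT f z s) t

dotT :: Num a => Tree sh a -> Tree sh a -> a
dotT t1 t2 = foldlT (+) 0 (zipWithT (*) t1 t2)

-- {1,{2,3}} and {4,{5,6}}, both of shape ((),((),())).
t123, t456 :: Tree ((), ((), ())) Integer
t123 = Branch (Leaf 1) (Branch (Leaf 2) (Leaf 3))
t456 = Branch (Leaf 4) (Branch (Leaf 5) (Leaf 6))
```

dotT t123 t456 evaluates to 32. A call such as dotT t123 (Branch (Leaf 4) (Leaf 5)) is rejected at compile time, because the shapes ((),((),())) and ((),()) differ.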
Generalising to arbitrary data structures
Any seasoned Haskell veteran knows the utility of type classes such as Functor, Applicative, and Foldable. We have now seen that a dot product is essentially a zipWith followed by a fold. (It makes little difference whether it’s a left or a right fold.)
Moreover, zipWith is really just liftA2 (found in module Control.Applicative) on the ZipList data structure. This leads us to the following definition:
import Control.Applicative (liftA2)

dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot x y = foldl (+) 0 (liftA2 (*) x y)
This function requires instances for Functor, Foldable and Applicative. Instances for the first two type classes are easy to write, and in some cases they can be derived using Haskell’s deriving syntax. I will therefore only discuss Applicative instances in this post. (The Functor and Foldable instances for vectors and shape-encoded trees are left as an exercise for the reader.)
One might reasonably wonder whether the two arguments to dot must still have the same shape. It turns out that they do, and for similar reasons. I’ll demonstrate this by looking at how to define Applicative instances for lists, vectors and trees.
Lists
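The default Applicative instance for lists is unsuitable for a generic dot product. It combines every element of one list with every element of the other, rather than pairing them up position by position. The Applicative instance on the wrapper type ZipList is adequate, but it has an unsatisfying definition for pure (to say the least).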
instance Applicative ZipList where
  pure x = ZipList (repeat x)
  ZipList fs <*> ZipList xs = ZipList (zipWith id fs xs)
Of course, this is necessary for lists, since we can’t guarantee that only lists of the same length will be combined. How else would you define pure so that the following works for a list xs of any length?
(+) <$> (pure 1) <*> (ZipList xs)
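The sketch below assumes a reasonably recent base, in which ZipList has a Foldable instance. Under that assumption, the generic dot works on ZipList directly. The sketch also evaluates the pure example above for a concrete xs. plusOne and zipDot are illustrative names:

```haskell
import Control.Applicative (ZipList (..), liftA2)

-- The generic dot product from above.
dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot x y = foldl (+) 0 (liftA2 (*) x y)

-- pure 1 is an infinite ZipList, but zipping truncates it to the
-- length of the finite list.
plusOne :: [Integer]
plusOne = getZipList ((+) <$> pure 1 <*> ZipList [2, 3, 4])

-- The generic dot product on ZipList: 1*4 + 2*5 + 3*6.
zipDot :: Integer
zipDot = dot (ZipList [1, 2, 3]) (ZipList [4, 5, 6])
```

plusOne is [3, 4, 5], because the infinite list produced by pure 1 is cut down to the length of the finite one. zipDot evaluates to 32.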
The definition of pure is much more satisfying for vectors.
Vectors
Obviously we want a similar definition for pure as for lists (ZipList). But we don’t want to produce an infinite list, just one of the appropriate length.
Defining the Applicative instance for vectors leads to an interesting observation that holds in general. For any data structure that encodes its own shape:
- You need one Applicative instance for each constructor of the data type.
- The instance heads must mirror the types of the constructors.
The code below has two instances, and each instance head closely mirrors the corresponding data constructor’s type. For example, Cons :: a -> Vec n a -> Vec (S n) a mirrors instance Applicative (Vec n) => Applicative (Vec (S n)).
instance Applicative (Vec Z) where
  pure _ = Nil
  Nil <*> Nil = Nil

instance Applicative (Vec n) => Applicative (Vec (S n)) where
  pure a = a `Cons` pure a
  (fa `Cons` fas) <*> (a `Cons` as) = fa a `Cons` (fas <*> as)
That’s it. Function pure will produce a vector of exactly the right length, as determined by its type.
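Here is a self-contained sketch that combines these instances with the generic dot. The Functor and Foldable instances are one possible answer to the exercise above, not the author’s. The names Three, v123, v456 and sevens are illustrative:

```haskell
{-# LANGUAGE GADTs, FlexibleInstances #-}

import Control.Applicative (liftA2)
import Data.Foldable (toList)

data Z
data S n

infixr 5 `Cons`
data Vec n a where
  Nil  :: Vec Z a
  Cons :: a -> Vec n a -> Vec (S n) a

-- One possible answer to the Functor/Foldable exercise.
instance Functor (Vec n) where
  fmap _ Nil         = Nil
  fmap f (Cons x xs) = f x `Cons` fmap f xs

instance Foldable (Vec n) where
  foldr _ z Nil         = z
  foldr f z (Cons x xs) = f x (foldr f z xs)

-- The Applicative instances from the post: one per constructor.
instance Applicative (Vec Z) where
  pure _ = Nil
  Nil <*> Nil = Nil

instance Applicative (Vec n) => Applicative (Vec (S n)) where
  pure a = a `Cons` pure a
  (f `Cons` fs) <*> (x `Cons` xs) = f x `Cons` (fs <*> xs)

-- The generic dot product.
dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot x y = foldl (+) 0 (liftA2 (*) x y)

type Three = S (S (S Z))

v123, v456 :: Vec Three Integer
v123 = 1 `Cons` 2 `Cons` 3 `Cons` Nil
v456 = 4 `Cons` 5 `Cons` 6 `Cons` Nil

-- The type annotation alone fixes how many copies pure produces.
sevens :: [Integer]
sevens = toList (pure 7 :: Vec Three Integer)
```

dot v123 v456 evaluates to 32, and sevens is [7, 7, 7].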
Trees
Unlike the case for lists, it’s hard to define an Applicative instance for non-shape-encoded trees. Let’s have a look.
instance Applicative Tree where
  pure a = Leaf a
  (Leaf fa)      <*> (Leaf b)     = Leaf (fa b)
  (Branch fa fb) <*> (Branch a b) = Branch (fa <*> a) (fb <*> b)
  (Leaf fa)      <*> (Branch a b) = {- ? -} undefined
  (Branch fa fb) <*> (Leaf a)     = {- ? -} undefined
This problem has been noticed before on the Haskell-beginners mailing list. The response is interesting because it mentions the “unpleasant property of returning infinite tree[s]”; the same problem we had with lists!
With shape-encoded trees this is not a problem, because function pure produces a tree of the appropriate shape. Also, note how the head of the second instance mirrors the type of the Branch constructor, Branch :: Tree m a -> Tree n a -> Tree (m,n) a.
instance Applicative (Tree ()) where
  pure a = Leaf a
  Leaf fa <*> Leaf a = Leaf (fa a)

instance (Applicative (Tree m), Applicative (Tree n))
      => Applicative (Tree (m,n)) where
  pure a = Branch (pure a) (pure a)
  (Branch fs ft) <*> (Branch s t) = Branch (fs <*> s) (ft <*> t)
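As with vectors, here is a self-contained sketch that combines the shape-encoded tree instances with the generic dot. The Functor and Foldable instances are one possible answer to the exercise, not the author’s. The names Shape, t123, t456 and zeros are illustrative:

```haskell
{-# LANGUAGE GADTs, FlexibleInstances, FlexibleContexts #-}

import Control.Applicative (liftA2)
import Data.Foldable (toList)

data Tree sh a where
  Leaf   :: a -> Tree () a
  Branch :: Tree m a -> Tree n a -> Tree (m, n) a

-- One possible answer to the Functor/Foldable exercise.
instance Functor (Tree sh) where
  fmap f (Leaf a)     = Leaf (f a)
  fmap f (Branch s t) = Branch (fmap f s) (fmap f t)

-- Leaves are visited left to right.
instance Foldable (Tree sh) where
  foldr f z (Leaf a)     = f a z
  foldr f z (Branch s t) = foldr f (foldr f z t) s

-- The Applicative instances from the post: one per constructor.
instance Applicative (Tree ()) where
  pure a = Leaf a
  Leaf fa <*> Leaf a = Leaf (fa a)

instance (Applicative (Tree m), Applicative (Tree n))
      => Applicative (Tree (m, n)) where
  pure a = Branch (pure a) (pure a)
  Branch fs ft <*> Branch s t = Branch (fs <*> s) (ft <*> t)

-- The generic dot product.
dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot x y = foldl (+) 0 (liftA2 (*) x y)

type Shape = ((), ((), ()))

t123, t456 :: Tree Shape Integer
t123 = Branch (Leaf 1) (Branch (Leaf 2) (Leaf 3))
t456 = Branch (Leaf 4) (Branch (Leaf 5) (Leaf 6))

-- pure builds a tree of exactly the shape named in the type.
zeros :: [Integer]
zeros = toList (pure 0 :: Tree Shape Integer)
```

dot t123 t456 evaluates to 32, and zeros is [0, 0, 0], with one 0 per leaf of the shape ((),((),())).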
Arbitrary binary associative operators
Phew, that’s it. We now have an implementation of dot that works on an arbitrary data structure, as long as one can define Functor, Foldable and Applicative instances for it. We have also learned that it is a good idea to encode a data structure’s shape in its type, so that Applicative instances can be defined. (This will be important later on when we want to take the transpose of generic matrices, but I’m getting ahead of myself.)
But what if you want to use binary associative operators other than addition and multiplication for the dot product? This is easy using Haskell’s Monoid type class, which plays nicely with the Foldable type class. In fact, it lets us omit any mention of identity elements by using the method fold :: (Foldable t, Monoid m) => t m -> m. We define an even more generic dot product as follows:
import Control.Applicative (liftA2)
import Data.Foldable (fold)

dotGen :: (Foldable f, Applicative f, Monoid p, Monoid s)
       => (a -> p, p -> a) -> (a -> s, s -> a) -> f a -> f a -> a
dotGen (pinject, pproject) (sinject, sproject) x y =
    sproject . fold . fmap (sinject . pproject) $ liftA2 mappend px py
  where
    px = fmap pinject x
    py = fmap pinject y
This function takes two pairs of functions for injecting into and projecting out of monoids. The first pair is for the “multiplication” monoid and the second is for the “summation” monoid. We can then define our original dot function using the existing Sum and Product wrapper types.
import Data.Monoid (Product (..), Sum (..))

dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot = dotGen (Product, getProduct) (Sum, getSum)
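The following sketch shows what else dotGen can express by plugging in a few standard monoid wrappers from Data.Monoid and Data.Semigroup. The max-plus and Boolean variants, and the names maxPlus and overlaps, are my own illustrations rather than part of the post. ZipList is used as the container. This assumes a base recent enough for ZipList to be Foldable:

```haskell
import Control.Applicative (ZipList (..), liftA2)
import Data.Foldable (fold)
import Data.Monoid (All (..), Any (..), Product (..), Sum (..))
import Data.Semigroup (Max (..))

dotGen :: (Foldable f, Applicative f, Monoid p, Monoid s)
       => (a -> p, p -> a) -> (a -> s, s -> a) -> f a -> f a -> a
dotGen (pinject, pproject) (sinject, sproject) x y =
    sproject . fold . fmap (sinject . pproject) $ liftA2 mappend px py
  where
    px = fmap pinject x
    py = fmap pinject y

-- The ordinary dot product: multiply pairwise, then sum.
dot :: (Num a, Foldable f, Applicative f) => f a -> f a -> a
dot = dotGen (Product, getProduct) (Sum, getSum)

-- Max-plus variant: add pairwise, then take the maximum.
-- Max a is a Monoid only when a is Bounded, hence Int rather than Integer.
maxPlus :: (Foldable f, Applicative f) => f Int -> f Int -> Int
maxPlus = dotGen (Sum, getSum) (Max, getMax)

-- Boolean variant: AND pairwise, then OR the results together.
overlaps :: (Foldable f, Applicative f) => f Bool -> f Bool -> Bool
overlaps = dotGen (All, getAll) (Any, getAny)
```

On ZipList [1,2,3] and ZipList [4,5,6], dot gives 32 and maxPlus gives 9, the largest of 5, 7 and 9. overlaps is True exactly when some position is True in both arguments. Max needs Bounded because its identity element is minBound.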
In the next episode…
In my next post we will consider generic matrix multiplication. This operation is defined over arbitrary collections of collections of numbers and, naturally, makes use of our generic dot product. Until then, adios.
Slides
On 17 Nov 2011 I gave a talk at fp-syd about this work. You can find the slides here.