Simple FFT in Haskell

This article develops a simple implementation of the fast Fourier transform in Haskell.

Raw performance of the algorithm is explicitly not a goal here; for instance, I use things like nub, Writer, and lists for simplicity. On the other hand, I do pay attention to the algorithmic complexity in terms of the number of arithmetic operations performed; the analysis thereof will be done in a subsequent article.

Background

The discrete Fourier transform turns \(n\) complex numbers \(a_0,a_1,\ldots,a_{n-1}\) into \(n\) complex numbers

\[f_k = \sum_{l=0}^{n-1} e^{- 2 \pi i k l / n} a_l.\]

An alternative way to think about \(f_k\) is as the values of the polynomial

\[P(x)=\sum_{l=0}^{n-1} a_l x^l\]

at the \(n\) points \(w^0,w^1,\ldots,w^{n-1}\), where \(w=e^{-2 \pi i / n}\) is a primitive \(n\)th root of unity.

The naive calculation requires \(\Theta(n^2)\) operations; our goal is to reduce that number to \(\Theta(n \log n)\).
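For comparison, here is a minimal sketch of the naive calculation, evaluating the defining sum directly for every \(k\). The name naiveDFT and the restriction to Complex Double are my own choices for illustration and are not used elsewhere in this article.

```haskell
import Data.Complex

-- | Naive DFT: for each of the n outputs, sum over all n inputs.
-- This performs on the order of n^2 multiplications and additions.
naiveDFT :: [Complex Double] -> [Complex Double]
naiveDFT as =
  [ sum [ a * w k l | (l, a) <- zip [0 ..] as ]
  | k <- [0 .. n - 1]
  ]
  where
    n = length as
    -- The factor e^{-2 pi i k l / n} from the definition of f_k.
    w :: Int -> Int -> Complex Double
    w k l = mkPolar 1 (-2 * pi * fromIntegral (k * l) / fromIntegral n)
```

The nested list comprehension makes the quadratic cost visible: the inner sum runs over all \(n\) coefficients for each of the \(n\) values of \(k\).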

An excellent explanation of the algorithm (which inspired this article in the first place) is given by Daniel Gusfield in his video lectures; he calls it “the most important algorithm that most computer scientists have never studied”. You only need to watch the first two lectures (and maybe the beginning of the third one) to understand the algorithm and this article.

Roots of unity

Roots of unity could in principle be represented in the Cartesian form by the Complex a type. However, that would make it very hard to compare them for equality, which we are going to do to achieve a subquadratic complexity.

So here’s a small module just for representing these special complex numbers in the polar form, taking advantage of the fact that their absolute values are always 1 and their phases are rational multiples of \(\pi\).

    module RootOfUnity
      ( U -- abstract
      , mkU
      , toComplex
      , u_pow
      , u_sqr
      ) where

    import Data.Complex

    -- | U q corresponds to the complex number exp(2 i pi q)
    newtype U = U Rational
      deriving (Show, Eq, Ord)

    -- | Convert a U number to the equivalent complex number
    toComplex :: Floating a => U -> Complex a
    toComplex (U q) = mkPolar 1 (2 * pi * realToFrac q)

    -- | Smart constructor for U numbers; automatically performs normalization
    mkU :: Rational -> U
    mkU q = U (q - realToFrac (floor q))

    -- | Raise a U number to a power
    u_pow :: U -> Integer -> U
    u_pow (U q) p = mkU (fromIntegral p * q)

    -- | Square a U number
    u_sqr :: U -> U
    u_sqr x = u_pow x 2
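To illustrate why the exact representation matters, here is a small sketch of equality checks on roots of unity. It repeats a trimmed copy of the U type so that it runs on its own; the names w8 and checks are hypothetical and exist only for this example.

```haskell
import Data.Ratio ((%))

-- Local copy of the representation above: U q stands for exp(2 i pi q).
newtype U = U Rational
  deriving (Show, Eq, Ord)

-- Normalize the phase into the interval [0, 1).
mkU :: Rational -> U
mkU q = U (q - fromIntegral (floor q :: Integer))

u_pow :: U -> Integer -> U
u_pow (U q) p = mkU (fromIntegral p * q)

-- | w = exp(-2 pi i / 8), a primitive 8th root of unity.
w8 :: U
w8 = mkU (-1 % 8)

-- | Every entry is True, with no floating-point tolerance involved.
checks :: [Bool]
checks =
  [ mkU (9 % 8) == mkU (1 % 8) -- phases that differ by a whole turn coincide
  , u_pow w8 8 == mkU 0        -- w^8 is exactly 1
  , u_pow w8 3 /= u_pow w8 5   -- distinct powers stay distinct
  ]
```

Because the phase is stored as a normalized Rational, the derived Eq and Ord instances compare roots of unity exactly, which is what lets them serve as keys in a Map later on.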

Fast Fourier transform

    {-# LANGUAGE ScopedTypeVariables #-}
    module FFT (fft) where

    import Data.Complex
    import Data.Ratio
    import Data.Monoid
    import qualified Data.Map as Map
    import Data.List
    import Data.Bifunctor
    import Control.Monad.Trans.Writer
    import RootOfUnity

So we want to evaluate the polynomial \(P(x)=\sum_{l=0}^{n-1}a_lx^l\) at points \(w^k\). The trick is to represent \(P(x)\) as \(A_e(x^2) + x A_o(x^2)\), where \(A_e(x)=a_0+a_2 x + \ldots\) and \(A_o(x)=a_1+a_3 x + \ldots\) are polynomials constructed out of the even-numbered and odd-numbered coefficients of \(P\), respectively.
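As a worked instance of this decomposition, take \(n=4\):

```latex
\begin{align*}
P(x) &= a_0 + a_1 x + a_2 x^2 + a_3 x^3 \\
     &= (a_0 + a_2 x^2) + x\,(a_1 + a_3 x^2) \\
     &= A_e(x^2) + x\,A_o(x^2),
\qquad A_e(y) = a_0 + a_2 y,\quad A_o(y) = a_1 + a_3 y.
\end{align*}
```

Each of \(A_e\) and \(A_o\) has half as many coefficients as \(P\), and both are evaluated at the same argument \(x^2\).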

When \(x\) is a root of unity, so is \(x^2\); this allows us to apply the algorithm recursively to evaluate \(A_e\) and \(A_o\) for the squared numbers.

But the real boon comes when \(n\) is even; then there will be half as many of these squared numbers, because \(w^k\) and \(w^{k+n/2}\), when squared, both give the same number \(w^{2k}\). This is when the divide and conquer strategy really pays off.
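The coincidence of the squares follows directly from \(w^n=1\):

```latex
\[
\left(w^{k+n/2}\right)^2 = w^{2k}\, w^{n} = w^{2k},
\qquad\text{since } w^{n} = e^{-2\pi i} = 1 .
\]
```

For example, with \(n=8\) the eight points \(w^0,\ldots,w^7\) have only four distinct squares: \(w^0\), \(w^2\), \(w^4\), and \(w^6\).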

We will represent a polynomial \(\sum_{l=0}^{n-1}a_lx^l\) in Haskell as a list of coefficients [a_0,a_1,...], starting with \(a_0\).

To be able to split a polynomial into the even and odd parts, let’s define a corresponding list function:

    split :: [a] -> ([a], [a])
    split = foldr f ([], [])
      where
        f a (r1, r2) = (a : r2, r1)

(I think I learned the idea of this elegant implementation from Dominic Steinitz.)
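Here is a small usage sketch. It repeats the definition of split so that the snippet runs on its own; the name example is hypothetical.

```haskell
split :: [a] -> ([a], [a])
split = foldr f ([], [])
  where
    -- Prepend the new element to what used to be the second list,
    -- then swap the roles of the two lists.
    f a (r1, r2) = (a : r2, r1)

-- | Coefficients of 1 + 2x + 3x^2 + 4x^3 + 5x^4, split into
-- the even-numbered part [1,3,5] and the odd-numbered part [2,4].
example :: ([Int], [Int])
example = split [1, 2, 3, 4, 5]
```

The swap on every step is what makes the elements alternate between the two lists, and the first element of the input always ends up at the head of the first list.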

Now, the core of the algorithm: a function that evaluates a polynomial at a given list of points on the unit circle. It tracks the number of performed arithmetic operations through a Writer monad over the Sum monoid.
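In case the counting technique is unfamiliar, here is a minimal sketch of it in isolation. The functions countedAdd and countedSum are hypothetical helpers for illustration, not part of the FFT module.

```haskell
import Control.Monad (foldM)
import Control.Monad.Trans.Writer (Writer, runWriter, tell)
import Data.Monoid (Sum (..))

-- | Add two numbers and record one arithmetic operation.
countedAdd :: Int -> Int -> Writer (Sum Int) Int
countedAdd x y = do
  tell (Sum 1)
  return (x + y)

-- | Sum a list starting from 0, counting every addition performed.
countedSum :: [Int] -> Writer (Sum Int) Int
countedSum = foldM countedAdd 0

-- | Run the computation, returning the result and the operation count.
sumWithCount :: [Int] -> (Int, Int)
sumWithCount xs =
  let (r, Sum ops) = runWriter (countedSum xs)
  in (r, ops)
```

Each call to tell contributes to a running total that the Sum monoid accumulates behind the scenes, so the counting never clutters the arithmetic itself.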

    evalFourier
      :: forall a . RealFloat a
      => [Complex a] -- ^ polynomial coefficients, starting from a_0
      -> [U] -- ^ points at which to evaluate the polynomial
      -> Writer (Sum Int) [Complex a]

If the polynomial is a constant, there’s not much to calculate. This is our base case.

    evalFourier [] pts = return $ 0 <$ pts
    evalFourier [c] pts = return $ c <$ pts

Otherwise, use the recursive algorithm outlined above.

    evalFourier coeffs pts = do
      let
        squares = nub $ u_sqr <$> pts -- values of x^2
        (even_coeffs, odd_coeffs) = split coeffs
      even_values <- evalFourier even_coeffs squares
      odd_values <- evalFourier odd_coeffs squares

      let
        -- a mapping from x^2 to (A_e(x^2), A_o(x^2))
        square_map =
          Map.fromList
          . zip squares
          $ zip even_values odd_values

        -- evaluate the polynomial at a single point
        eval1 :: U -> Writer (Sum Int) (Complex a)
        eval1 x = do
          let
            (ye,yo) = (square_map Map.! u_sqr x)
            r = ye + toComplex x * yo
          tell $ Sum 2 -- this took two arithmetic operations
          return r

      mapM eval1 pts

The actual FFT function is a simple wrapper around evalFourier which substitutes the specific points and performs some simple conversions. It returns the result of the DFT and the number of operations performed.
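As a rough sketch of what such a wrapper could look like, here is a single-file version that repeats trimmed copies of the definitions above so that it runs on its own. The body of fft below is my reconstruction from the description and is an assumption; in particular, the choice of points \(w^k\) with \(w=e^{-2\pi i/n}\) follows the Background section, and the local function eval1 is left without a type signature so that no language extension is needed.

```haskell
import Control.Monad.Trans.Writer (Writer, runWriter, tell)
import Data.Bifunctor (second)
import Data.Complex
import Data.List (genericLength, nub)
import qualified Data.Map as Map
import Data.Monoid (Sum (..))
import Data.Ratio ((%))

-- U q stands for exp(2 i pi q), with q normalized into [0, 1).
newtype U = U Rational
  deriving (Show, Eq, Ord)

toComplex :: Floating a => U -> Complex a
toComplex (U q) = mkPolar 1 (2 * pi * realToFrac q)

mkU :: Rational -> U
mkU q = U (q - fromIntegral (floor q :: Integer))

u_pow :: U -> Integer -> U
u_pow (U q) p = mkU (fromIntegral p * q)

u_sqr :: U -> U
u_sqr x = u_pow x 2

split :: [a] -> ([a], [a])
split = foldr (\a (r1, r2) -> (a : r2, r1)) ([], [])

evalFourier :: RealFloat a => [Complex a] -> [U] -> Writer (Sum Int) [Complex a]
evalFourier [] pts = return (0 <$ pts)
evalFourier [c] pts = return (c <$ pts)
evalFourier coeffs pts = do
  let squares = nub (u_sqr <$> pts)
      (evens, odds) = split coeffs
  evenVals <- evalFourier evens squares
  oddVals <- evalFourier odds squares
  let squareMap = Map.fromList (zip squares (zip evenVals oddVals))
      eval1 x = do
        let (ye, yo) = squareMap Map.! u_sqr x
        tell (Sum 2) -- one multiplication and one addition
        return (ye + toComplex x * yo)
  mapM eval1 pts

-- | Assumed wrapper: evaluate at w^0 .. w^(n-1) with w = exp(-2 pi i / n),
-- then unwrap the Writer into (DFT values, operation count).
fft :: RealFloat a => [Complex a] -> ([Complex a], Int)
fft coeffs = second getSum (runWriter (evalFourier coeffs pts))
  where
    n = genericLength coeffs :: Integer
    w = mkU (-1 % n)
    pts = [u_pow w k | k <- [0 .. n - 1]]
```

For instance, fft [1, 1, 1, 1 :: Complex Double] evaluates \(1+x+x^2+x^3\) at the four 4th roots of unity; the first component of the result holds the transformed values and the second the number of arithmetic operations recorded by tell.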