summaryrefslogtreecommitdiff
path: root/lore/System/Random/Shuffle.hs
blob: 02cd3e00446dd009c5d13a8693fbe889d931fe87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
{- |
Module      : System.Random.Shuffle
Copyright   : (c) 2009 Oleg Kiselyov, Manlio Perillo
License     : BSD3 (see LICENSE file)

Random shuffling of lists, using the tree-based algorithm described at
<http://okmij.org/ftp/Haskell/perfect-shuffle.txt>.

Example:

> import System.Random (newStdGen)
> import System.Random.Shuffle (shuffle')
>
> main = do
>   rng <- newStdGen
>   let xs = [1,2,3,4,5]
>   print $ shuffle' xs (length xs) rng
-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

module System.Random.Shuffle
  ( shuffle
  , shuffle'
  , shuffleM
  )
where

import           Data.Function                  ( fix )
import           System.Random                  ( RandomGen
                                                , randomR
                                                )
import           Control.Monad                  ( liftM
                                                , liftM2
                                                )
import           Control.Monad.Random           ( MonadRandom
                                                , getRandomR
                                                )


-- | A full binary tree: every internal node has exactly two children.
-- Leaves hold the elements; an internal node @Node card l r@ caches
-- @card@, the number of leaves beneath it.
--
-- Invariant: @card >= 2@, and internal nodes stay full as elements are
-- extracted from the tree.
data Tree a
  = Leaf !a
  | Node !Int !(Tree a) !(Tree a)
  deriving Show


-- | Convert a non-empty sequence @(e1 ... en)@ into a full binary tree whose
-- leaves, read left to right, are @e1 ... en@.
--
-- Each pass over the current level joins adjacent subtrees pairwise (an odd
-- subtree at the end is carried over unchanged) until one root remains.
-- Calls 'error' on an empty list, which no 'Tree' can represent.
buildTree :: [a] -> Tree a
buildTree []       = error "[buildTree] called with an empty list"
buildTree elements = fix growLevel (map Leaf elements)
 where
  -- Repeat pairing passes until a single tree is left. The level never
  -- becomes empty: pairing a non-empty list yields a non-empty list.
  growLevel _    [node] = node
  growLevel self level  = self (pairUp level)

  -- One pass: join neighbours, forcing both subtrees before joining them.
  pairUp (t1 : t2 : ts) = t1 `seq` t2 `seq` joinTrees t1 t2 : pairUp ts
  pairUp ts             = ts

  -- Combine two subtrees under a node recording their total leaf count.
  joinTrees l r = Node (leafCount l + leafCount r) l r

  leafCount (Leaf _    ) = 1
  leafCount (Node c _ _) = c


-- | Given a sequence @(e1, ..., en)@ to shuffle, and a sequence
-- @(r1, ..., r[n-1])@ of numbers such that @r[i]@ is an independent sample
-- from a uniform random distribution @[0 .. n-i]@, compute the
-- corresponding permutation of the input sequence.
--
-- Each @r[i]@ selects, by index, which of the remaining elements is emitted
-- next. An empty input sequence yields @[]@ without inspecting the numbers.
-- Otherwise the function calls 'error' if the number sequence does not have
-- exactly @n-1@ elements, or if some @r[i]@ lies outside @[0 .. n-i]@.
shuffle :: [a] -> [Int] -> [a]
shuffle []       _    = []
shuffle elements nums = shuffleTree (buildTree elements) nums
 where
  -- Emit one element per random index until a single leaf remains; the
  -- last leaf needs no random choice, hence only n-1 numbers.
  shuffleTree (Leaf e) []       = [e]
  shuffleTree (Leaf _) (_ : _)  = lengthMismatch
  shuffleTree tree     (r : rs) =
    let (b, rest) = extractTree r tree in b : shuffleTree rest rs
  shuffleTree _        []       = lengthMismatch

  lengthMismatch = error "[shuffle] called with lists of different lengths"

  -- Extracts the element at 0-based index n (left-to-right leaf order)
  -- and returns it paired with the tree with that element deleted.
  -- All internal nodes of the returned tree remain full.
  -- Index 0 with a leaf on the left: drop that leaf.
  extractTree 0 (Node _ (Leaf e) r       ) = (e, r)
  -- Two-leaf node, index 1: the left leaf is all that remains.
  extractTree 1 (Node 2 (Leaf l) (Leaf r)) = (r, Leaf l)
  -- Leaf on the left, index >= 1: the element lives in the right subtree.
  extractTree n (Node c (Leaf l) r) =
    let (e, r') = extractTree (n - 1) r in (e, Node (c - 1) (Leaf l) r')
  -- Index of the last element and a leaf on the right: drop that leaf.
  extractTree n (Node n' l (Leaf e)) | n + 1 == n' = (e, l)
  -- Node on the left: descend into whichever side holds index n.
  extractTree n (Node c l@(Node cl _ _) r)
    | n < cl
    = let (e, l') = extractTree n l in (e, Node (c - 1) l' r)
    | otherwise
    = let (e, r') = extractTree (n - cl) r in (e, Node (c - 1) l r')
  -- Reached only when n is out of range for the tree.
  extractTree _ _ = error "[extractTree] impossible"

-- | Given a sequence @(e1, ..., en)@ to shuffle, its length @n@, and a random
-- generator, compute the corresponding permutation of the input sequence.
--
-- The caller must pass the true length of the list. An empty list yields
-- @[]@. A non-positive length produces no random numbers (rather than an
-- endless stream), so a length that does not fit the list ends in an
-- 'error' from 'shuffle' instead of non-termination.
shuffle' :: RandomGen gen => [a] -> Int -> gen -> [a]
shuffle' []       _   _   = []
shuffle' elements len gen = shuffle elements (rseq len gen)
 where
  -- The sequence (r1, ..., r[n-1]) of numbers such that r[i] is an
  -- independent sample from a uniform random distribution [0 .. n-i].
  -- Each draw uses the generator returned by the previous draw.
  rseq :: RandomGen g => Int -> g -> [Int]
  rseq n = go (n - 1)
   where
    -- Stop at i <= 0 (not only i == 0) so a non-positive n terminates.
    go :: RandomGen g => Int -> g -> [Int]
    go i g
      | i <= 0    = []
      | otherwise = let (j, g') = randomR (0, i) g in j : go (i - 1) g'

-- | Monadic counterpart of 'shuffle'': permute a list using random numbers
-- drawn from any 'MonadRandom' instance. The list length is computed
-- here rather than supplied by the caller. An empty list is returned
-- as-is without drawing any random numbers.
shuffleM :: (MonadRandom m) => [a] -> m [a]
shuffleM elements
  | null elements = return []
  | otherwise     = liftM (shuffle elements) samples
 where
  n = length elements
  -- Draw r1, ..., r[n-1] in that order, with r[i] uniform over [0 .. n-i].
  samples = foldr (liftM2 (:) . draw) (return []) [n - 1, n - 2 .. 1]
  draw i = getRandomR (0, i)