diff --git a/System/Random/MWC.hs b/System/Random/MWC.hs index cdd946a..6c01e36 100644 --- a/System/Random/MWC.hs +++ b/System/Random/MWC.hs @@ -1,6 +1,7 @@ {-# LANGUAGE BangPatterns, CPP, DeriveDataTypeable, FlexibleContexts, MagicHash, Rank2Types, ScopedTypeVariables, TypeFamilies, UnboxedTuples, ForeignFunctionInterface #-} + -- | -- Module : System.Random.MWC -- Copyright : (c) 2009-2012 Bryan O'Sullivan @@ -90,6 +91,9 @@ module System.Random.MWC , save , restore + -- * Fold + , foldMUniforms + -- * References -- $references ) where @@ -119,7 +123,7 @@ import Foreign.Marshal.Alloc (allocaBytes) import Foreign.Marshal.Array (peekArray) import qualified Data.Vector.Generic as G import qualified Data.Vector.Unboxed as I -import qualified Data.Vector.Unboxed.Mutable as M +import Data.Primitive.ByteArray import System.CPUTime (cpuTimePrecision, getCPUTime) import System.IO (IOMode(..), hGetBuf, hPutStrLn, stderr, withBinaryFile) import System.IO.Unsafe (unsafePerformIO) @@ -156,7 +160,7 @@ class Variate a where -- 2**(-33). To do the same with 'Double' variates, subtract -- 2**(-53). uniform :: (PrimMonad m) => Gen (PrimState m) -> m a - -- | Generate single uniformly distributed random variable in a + -- | Generate a single uniformly distributed random variable in a -- given range. -- -- * For integral types inclusive range is used. @@ -313,7 +317,7 @@ wordsToDouble x y = (fromIntegral u * m_inv_32 + (0.5 + m_inv_53) + -- | State of the pseudo-random number generator. It uses mutable -- state so same generator shouldn't be used from the different -- threads simultaneously. -newtype Gen s = Gen (M.MVector s Word32) +newtype Gen s = Gen (MutableByteArray s) -- | A shorter name for PRNG state in the 'IO' monad. 
type GenIO = Gen (PrimState IO) @@ -362,19 +366,19 @@ create = initialize defaultSeed initialize :: (PrimMonad m, Vector v Word32) => v Word32 -> m (Gen (PrimState m)) initialize seed = do - q <- M.unsafeNew 258 + q <- mkAlignedByteArray fill q if fini == 258 then do - M.unsafeWrite q ioff $ G.unsafeIndex seed ioff .&. 255 - M.unsafeWrite q coff $ G.unsafeIndex seed coff + writeByteArray q ioff $ G.unsafeIndex seed ioff .&. 255 + writeByteArray q coff $ G.unsafeIndex seed coff else do - M.unsafeWrite q ioff 255 - M.unsafeWrite q coff 362436 + writeByteArray q ioff (255 :: Word32) + writeByteArray q coff (362436 :: Word32) return (Gen q) where fill q = go 0 where go i | i == 256 = return () - | otherwise = M.unsafeWrite q i s >> go (i+1) + | otherwise = writeByteArray q i s >> go (i+1) where s | i >= fini = if fini == 0 then G.unsafeIndex defaultSeed i else G.unsafeIndex defaultSeed i `xor` @@ -396,16 +400,43 @@ newtype Seed = Seed { -- -- > restore (toSeed v) = initialize v toSeed :: (Vector v Word32) => v Word32 -> Seed -toSeed v = Seed $ I.create $ do { Gen q <- initialize v; return q } +toSeed v = + Seed $ I.create $ do + Gen q <- initialize v + unsafeFreezeByteArray q >>= I.unsafeThaw . byteArrayToVector + +byteArrayToVector :: (Vector v Word32) => ByteArray -> v Word32 +byteArrayToVector q = G.fromList $ + let nWord32 = quot (sizeofByteArray q) SIZEOF_WORD32 + in map (indexByteArray q) [0..nWord32-1] + +vectorToByteArray :: (Vector v Word32, PrimMonad m) => v Word32 -> m (MutableByteArray (PrimState m)) +vectorToByteArray v = do + b <- mkAlignedByteArray + mapM_ (uncurry $ writeByteArray b) $ zip [0..] $ G.toList v + return b + +mkAlignedByteArray :: PrimMonad m => m (MutableByteArray (PrimState m)) +mkAlignedByteArray = + -- The indexes ioff and coff (256,257) are read and written to an order of magnitude more + -- than other indexes, and always consecutively. Hence, it's important that the + -- corresponding memory sits on the same cache line.
We also want the overall array to + -- use the least count of cache lines. + -- + -- Assuming 64-byte cache lines, a 64-byte alignment meets the aforementioned + -- requirements. + newAlignedPinnedByteArray (258 * SIZEOF_WORD32) 64 -- | Save the state of a 'Gen', for later use by 'restore'. save :: PrimMonad m => Gen (PrimState m) -> m Seed -save (Gen q) = Seed `liftM` G.freeze q +-- unsafeFreezeByteArray aliases the live state and byteArrayToVector is lazy, so force the copy +-- before returning; otherwise later draws from the generator could leak into the saved Seed. +save (Gen q) = do { frozen <- unsafeFreezeByteArray q; return $! Seed (byteArrayToVector frozen) } {-# INLINE save #-} -- | Create a new 'Gen' that mirrors the state of a saved 'Seed'. restore :: PrimMonad m => Seed -> m (Gen (PrimState m)) -restore (Seed s) = Gen `liftM` G.thaw s +restore (Seed s) = Gen `liftM` vectorToByteArray s {-# INLINE restore #-} @@ -520,19 +551,25 @@ aa :: Word64 aa = 1540315826 {-# INLINE aa #-} +{-# INLINE read32 #-} +read32 :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> m Word32 +read32 b i = + readByteArray b i + + uniformWord32 :: PrimMonad m => Gen (PrimState m) -> m Word32 uniformWord32 (Gen q) = do - i <- nextIndex `liftM` M.unsafeRead q ioff - c <- fromIntegral `liftM` M.unsafeRead q coff - qi <- fromIntegral `liftM` M.unsafeRead q i + i <- nextIndex `liftM` read32 q ioff + c <- fromIntegral `liftM` read32 q coff + qi <- fromIntegral `liftM` read32 q i let t = aa * qi + c c' = fromIntegral (t `shiftR` 32) x = fromIntegral t + c' (# x', c'' #) | x < c' = (# x + 1, c' + 1 #) | otherwise = (# x, c' #) - M.unsafeWrite q i x' - M.unsafeWrite q ioff (fromIntegral i) - M.unsafeWrite q coff (fromIntegral c'') + writeByteArray q i x' + writeByteArray q ioff (fromIntegral i :: Word32) + writeByteArray q coff c'' return x' {-# INLINE uniformWord32 #-} @@ -544,11 +581,11 @@ uniform1 f gen = do uniform2 :: PrimMonad m => (Word32 -> Word32 -> a) -> Gen (PrimState m) -> m a uniform2 f (Gen q) = do - i <- nextIndex `liftM` M.unsafeRead q ioff + i <- nextIndex `liftM` read32 q ioff let
j = nextIndex i - c <- fromIntegral `liftM` M.unsafeRead q coff - qi <- fromIntegral `liftM` M.unsafeRead q i - qj <- fromIntegral `liftM` M.unsafeRead q j + c <- fromIntegral `liftM` read32 q coff + qi <- fromIntegral `liftM` read32 q i + qj <- fromIntegral `liftM` read32 q j let t = aa * qi + c c' = fromIntegral (t `shiftR` 32) x = fromIntegral t + c' @@ -559,13 +596,68 @@ uniform2 f (Gen q) = do y = fromIntegral u + d' (# y', d'' #) | y < d' = (# y + 1, d' + 1 #) | otherwise = (# y, d' #) - M.unsafeWrite q i x' - M.unsafeWrite q j y' - M.unsafeWrite q ioff (fromIntegral j) - M.unsafeWrite q coff (fromIntegral d'') + writeByteArray q i x' + writeByteArray q j y' + writeByteArray q ioff (fromIntegral j :: Word32) + writeByteArray q coff d'' return $! f x' y' {-# INLINE uniform2 #-} +data AccumWithUniforms a = AWM { + _coeff :: {-# UNPACK #-} !Word32 + , _index :: {-# UNPACK #-} !Int + , _accumulator :: !a +} + +-- | Fold-like function that consumes random numbers as they are produced, +-- using a minimal number of reads and writes to the generator state. +-- +-- To generate @n@ numbers, this function does @n + 2@ reads and @n + 2@ writes.
+foldMUniforms :: PrimMonad m + => Int + -- ^ How many 'Word32' should be generated (none when not positive) + -> (a -> Word32 -> m a) + -- ^ The accumulating function + -> a + -- ^ The accumulator's initial value + -> Gen (PrimState m) + -- ^ The RNG + -> m a +foldMUniforms n f acc0 (Gen q) = do + i0 <- fromIntegral `liftM` read32 q ioff + c0 <- fromIntegral `liftM` read32 q coff + + let accum (AWM cPrev iPrev accPrev) = do + let i = nextIndex iPrev + qi <- fromIntegral `liftM` read32 q i + let t = aa * qi + fromIntegral cPrev + c' = fromIntegral (t `shiftR` 32) + x = fromIntegral t + c' + (# x', c'' #) | x < c' = (# x + 1, c' + 1 #) + | otherwise = (# x, c' #) + writeByteArray q i x' + AWM c'' i `liftM` f accPrev x' + + (AWM cF iF accF) <- iterateNM accum (AWM c0 i0 acc0) n + + writeByteArray q ioff (fromIntegral iF :: Word32) + writeByteArray q coff cF + + return accF + +{-# INLINE foldMUniforms #-} + +-- Apply @f@ @n@ times; equivalent to @foldM (\a _ -> f a) a0 [1..n]@, so @n <= 0@ yields @a0@. +iterateNM :: Monad m => (a -> m a) -> a -> Int -> m a +iterateNM f a0 n0 = + go n0 a0 + where + go n a | n <= 0 = return $! a + | otherwise = f a >>= go (n - 1) + +{-# INLINE iterateNM #-} + + -- Type family for fixed size integrals. For signed data types it's -- its unsigned couterpart with same size and for unsigned data types -- it's same type diff --git a/benchmarks/mwc-random-benchmarks.cabal b/benchmarks/mwc-random-benchmarks.cabal index c4d0dac..989a61e 100644 --- a/benchmarks/mwc-random-benchmarks.cabal +++ b/benchmarks/mwc-random-benchmarks.cabal @@ -15,4 +15,5 @@ executable bm criterion, mersenne-random, mwc-random, - random + random, + vector