Skip to content

Commit 44f6446

Browse files
committed
Use a fake GADT for sequence folds and traversals
1 parent 97af0e8 commit 44f6446

File tree

4 files changed

+260
-15
lines changed

4 files changed

+260
-15
lines changed

containers-tests/containers-tests.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ library
106106
Data.Map.Strict.Internal
107107
Data.Sequence
108108
Data.Sequence.Internal
109+
Data.Sequence.Internal.Depth
109110
Data.Sequence.Internal.Sorting
110111
Data.Set
111112
Data.Set.Internal

containers/containers.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ Library
7575
Data.Graph
7676
Data.Sequence
7777
Data.Sequence.Internal
78+
Data.Sequence.Internal.Depth
7879
Data.Sequence.Internal.Sorting
7980
Data.Tree
8081

containers/src/Data/Sequence/Internal.hs

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{- OPTIONS_GHC -ddump-simpl #-}
12
{-# LANGUAGE CPP #-}
23
#include "containers.h"
34
{-# LANGUAGE BangPatterns #-}
@@ -7,6 +8,7 @@
78
{-# LANGUAGE DeriveLift #-}
89
{-# LANGUAGE StandaloneDeriving #-}
910
{-# LANGUAGE FlexibleInstances #-}
11+
{-# LANGUAGE GADTs #-}
1012
{-# LANGUAGE InstanceSigs #-}
1113
{-# LANGUAGE ScopedTypeVariables #-}
1214
{-# LANGUAGE TemplateHaskellQuotes #-}
@@ -177,6 +179,7 @@ module Data.Sequence.Internal (
177179
node2,
178180
node3,
179181
#endif
182+
bongo
180183
) where
181184

182185
import Utils.Containers.Internal.Prelude hiding (
@@ -194,7 +197,7 @@ import Control.Applicative ((<$>), (<**>), Alternative,
194197
import qualified Control.Applicative as Applicative
195198
import Control.DeepSeq (NFData(rnf),NFData1(liftRnf))
196199
import Control.Monad (MonadPlus(..))
197-
import Data.Monoid (Monoid(..))
200+
import Data.Monoid (Monoid(..), Endo(..), Dual(..))
198201
import Data.Functor (Functor(..))
199202
import Utils.Containers.Internal.State (State(..), execState)
200203
import Data.Foldable (foldr', toList)
@@ -234,6 +237,7 @@ import Data.Functor.Identity (Identity(..))
234237
import Utils.Containers.Internal.StrictPair (StrictPair (..), toPair)
235238
import Control.Monad.Zip (MonadZip (..))
236239
import Control.Monad.Fix (MonadFix (..), fix)
240+
import Data.Sequence.Internal.Depth (Depth_ (..), Depth2_ (..))
237241

238242
default ()
239243

@@ -378,16 +382,38 @@ fmapSeq f (Seq xs) = Seq (fmap (fmap f) xs)
378382
#-}
379383
#endif
380384

385+
--type Depth = Depth_ Elem Node
386+
type Depth = Depth_ Node
387+
type Depth2 = Depth2_ Node
388+
381389
instance Foldable Seq where
382390
#ifdef __GLASGOW_HASKELL__
383391
foldMap :: forall m a. Monoid m => (a -> m) -> Seq a -> m
384-
foldMap = coerce (foldMap :: (Elem a -> m) -> FingerTree (Elem a) -> m)
392+
foldMap f (Seq t0) = foldMapFT Bottom t0
393+
where
394+
foldMapBlob :: Depth (Elem a) t -> t -> m
395+
foldMapBlob Bottom (Elem a) = f a
396+
foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y
397+
foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z
398+
399+
foldMapFT :: Depth (Elem a) t -> FingerTree t -> m
400+
foldMapFT !_ EmptyT = mempty
401+
foldMapFT w (Single t) = foldMapBlob w t
402+
foldMapFT w (Deep _ pr m sf) =
403+
foldMap (foldMapBlob w) pr
404+
<> foldMapFT (Deeper w) m
405+
<> foldMap (foldMapBlob w) sf
385406

386407
foldr :: forall a b. (a -> b -> b) -> b -> Seq a -> b
387-
foldr = coerce (foldr :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b)
408+
-- We define this explicitly so we can inline the foldMap. And we don't
409+
-- define it as a coercion of the FingerTree version because we want users
410+
-- to have the option of (effectively) inlining it explicitly.
411+
foldr f z t = appEndo (GHC.Exts.inline foldMap (coerce f) t) z
388412

389413
foldl :: forall b a. (b -> a -> b) -> b -> Seq a -> b
390-
foldl = coerce (foldl :: (b -> Elem a -> b) -> b -> FingerTree (Elem a) -> b)
414+
-- Should we define this by hand to associate optimally? Or is GHC
415+
-- clever enough to do that for us?
416+
foldl f z t = appEndo (getDual (GHC.Exts.inline foldMap (Dual . Endo . flip f) t)) z
391417

392418
foldr' :: forall a b. (a -> b -> b) -> b -> Seq a -> b
393419
foldr' = coerce (foldr' :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b)
@@ -426,7 +452,37 @@ instance Foldable Seq where
426452
instance Traversable Seq where
427453
#if __GLASGOW_HASKELL__
428454
{-# INLINABLE traverse #-}
429-
#endif
455+
traverse :: forall f a b. Applicative f => (a -> f b) -> Seq a -> f (Seq b)
456+
traverse f (Seq t0) = Seq <$> traverseFT Bottom2 t0
457+
where
458+
traverseFT :: Depth2 (Elem a) t (Elem b) u -> FingerTree t -> f (FingerTree u)
459+
traverseFT !_ EmptyT = pure EmptyT
460+
traverseFT w (Single t) = Single <$> traverseBlob w t
461+
traverseFT w (Deep s pr m sf) = liftA3 (Deep s)
462+
(traverse (traverseBlob w) pr)
463+
(traverseFT (Deeper2 w) m)
464+
(traverse (traverseBlob w) sf)
465+
466+
-- Traverse a 2-3 tree, given its height.
467+
traverseBlob :: Depth2 (Elem a) t (Elem b) u -> t -> f u
468+
traverseBlob Bottom2 (Elem a) = Elem <$> f a
469+
470+
-- We have a special case here to avoid needing to `fmap Elem` over
471+
-- each of the leaves, in case that's not free in the relevant functor.
472+
-- We still end up using extra fmaps for the very first level of the
473+
-- FingerTree and the Seq constructor. While we *could* avoid that,
474+
-- doing so requires a good bit of extra code to save *at most* nine
475+
-- fmap applications for the sequence. It would also save on Depth
476+
-- comparisons, but I doubt that matters very much.
477+
traverseBlob (Deeper2 Bottom2) (Node2 s (Elem x) (Elem y))
478+
= liftA2 (\x' y' -> Node2 s (Elem x') (Elem y')) (f x) (f y)
479+
traverseBlob (Deeper2 Bottom2) (Node3 s (Elem x) (Elem y) (Elem z))
480+
= liftA3 (\x' y' z' -> Node3 s (Elem x') (Elem y') (Elem z'))
481+
(f x) (f y) (f z)
482+
483+
traverseBlob (Deeper2 w) (Node2 s x y) = liftA2 (Node2 s) (traverseBlob w x) (traverseBlob w y)
484+
traverseBlob (Deeper2 w) (Node3 s x y z) = liftA3 (Node3 s) (traverseBlob w x) (traverseBlob w y) (traverseBlob w z)
485+
#else
430486
traverse _ (Seq EmptyT) = pure (Seq EmptyT)
431487
traverse f' (Seq (Single (Elem x'))) =
432488
(\x'' -> Seq (Single (Elem x''))) <$> f' x'
@@ -498,6 +554,7 @@ instance Traversable Seq where
498554
:: Applicative f
499555
=> (Node a -> f (Node b)) -> Node (Node a) -> f (Node (Node b))
500556
traverseNodeN f t = traverse f t
557+
#endif
501558

502559
instance NFData a => NFData (Seq a) where
503560
rnf (Seq xs) = rnf xs
@@ -1067,7 +1124,33 @@ instance Sized a => Sized (FingerTree a) where
10671124
size (Single x) = size x
10681125
size (Deep v _ _ _) = v
10691126

1127+
-- We fold FingerTrees directly, using the same Depth GADT as the Seq
1128+
-- folds but with an arbitrary element type @a@ at the bottom instead of
1129+
-- @Elem a@. The alternative, coercing to Seq and folding that, is kept
1130+
-- below as a commented-out definition. Note: the indexed folds on Seq
1131+
-- fix the bottom type to @Elem a@ so that every level of the tree is
1132+
-- known to be Sized.
10701133
instance Foldable FingerTree where
1134+
#ifdef __GLASGOW_HASKELL__
1135+
foldMap :: forall m a. Monoid m => (a -> m) -> FingerTree a -> m
1136+
foldMap f = foldMapFT Bottom
1137+
where
1138+
foldMapBlob :: Depth a t -> t -> m
1139+
foldMapBlob Bottom a = f a
1140+
foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y
1141+
foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z
1142+
1143+
foldMapFT :: Depth a t -> FingerTree t -> m
1144+
foldMapFT !_ EmptyT = mempty
1145+
foldMapFT w (Single t) = foldMapBlob w t
1146+
foldMapFT w (Deep _ pr m sf) =
1147+
foldMap (foldMapBlob w) pr
1148+
<> foldMapFT (Deeper w) m
1149+
<> foldMap (foldMapBlob w) sf
1150+
1151+
-- foldMap = coerce (foldMap :: (a -> m) -> Seq a -> m)
1152+
{-# INLINABLE foldMap #-}
1153+
#else
10711154
foldMap _ EmptyT = mempty
10721155
foldMap f' (Single x') = f' x'
10731156
foldMap f' (Deep _ pr' m' sf') =
@@ -1094,8 +1177,6 @@ instance Foldable FingerTree where
10941177

10951178
foldMapNodeN :: Monoid m => (Node a -> m) -> Node (Node a) -> m
10961179
foldMapNodeN f t = foldNode (<>) f t
1097-
#if __GLASGOW_HASKELL__
1098-
{-# INLINABLE foldMap #-}
10991180
#endif
11001181

11011182
foldr _ z' EmptyT = z'
@@ -1265,7 +1346,7 @@ foldDigit _ f (One a) = f a
12651346
foldDigit (<+>) f (Two a b) = f a <+> f b
12661347
foldDigit (<+>) f (Three a b c) = f a <+> f b <+> f c
12671348
foldDigit (<+>) f (Four a b c d) = f a <+> f b <+> f c <+> f d
1268-
{-# INLINE foldDigit #-}
1349+
{-# INLINABLE foldDigit #-}
12691350

12701351
instance Foldable Digit where
12711352
foldMap = foldDigit mappend
@@ -3234,15 +3315,56 @@ foldWithIndexNode (<+>) f s (Node3 _ a b c) = f s a <+> f sPsa b <+> f sPsab c
32343315
-- element in the sequence.
32353316
--
32363317
-- @since 0.5.8
3237-
foldMapWithIndex :: Monoid m => (Int -> a -> m) -> Seq a -> m
3318+
foldMapWithIndex :: forall m a. Monoid m => (Int -> a -> m) -> Seq a -> m
3319+
#ifdef __GLASGOW_HASKELL__
3320+
foldMapWithIndex f (Seq t) = foldMapWithIndexFT Bottom 0 t
3321+
where
3322+
foldMapWithIndexFT :: Depth (Elem a) t -> Int -> FingerTree t -> m
3323+
foldMapWithIndexFT !_ !_ EmptyT = mempty
3324+
foldMapWithIndexFT d s (Single xs) = foldMapWithIndexBlob d s xs
3325+
foldMapWithIndexFT d s (Deep _ pr m sf) = case depthSized d of { Sizzy ->
3326+
foldWithIndexDigit (<>) (foldMapWithIndexBlob d) s pr <>
3327+
foldMapWithIndexFT (Deeper d) sPspr m <>
3328+
foldWithIndexDigit (<>) (foldMapWithIndexBlob d) sPsprm sf
3329+
where
3330+
!sPspr = s + size pr
3331+
!sPsprm = sPspr + size m
3332+
}
3333+
3334+
foldMapWithIndexBlob :: Depth (Elem a) t -> Int -> t -> m
3335+
foldMapWithIndexBlob Bottom k (Elem a) = f k a
3336+
foldMapWithIndexBlob (Deeper yop) k (Node2 _s t1 t2) =
3337+
foldMapWithIndexBlob yop k t1 <>
3338+
foldMapWithIndexBlob yop (k + sizeBlob yop t1) t2
3339+
foldMapWithIndexBlob (Deeper yop) k (Node3 _s t1 t2 t3) =
3340+
foldMapWithIndexBlob yop k t1 <>
3341+
foldMapWithIndexBlob yop (k + st1) t2 <>
3342+
foldMapWithIndexBlob yop (k + st1t2) t3
3343+
where
3344+
st1 = sizeBlob yop t1
3345+
st1t2 = st1 + sizeBlob yop t2
3346+
{-# INLINABLE foldMapWithIndex #-}
3347+
3348+
data Sizzy a where
3349+
Sizzy :: Sized a => Sizzy a
3350+
3351+
depthSized :: Depth (Elem a) t -> Sizzy t
3352+
depthSized Bottom = Sizzy
3353+
depthSized (Deeper _) = Sizzy
3354+
3355+
sizeBlob :: Depth (Elem a) t -> t -> Int
3356+
sizeBlob Bottom = size
3357+
sizeBlob (Deeper _) = size
3358+
3359+
#else
32383360
foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs'
32393361
where
32403362
lift_elem :: (Int -> a -> m) -> (Int -> Elem a -> m)
3241-
#ifdef __GLASGOW_HASKELL__
3363+
# ifdef __GLASGOW_HASKELL__
32423364
lift_elem g = coerce g
3243-
#else
3365+
# else
32443366
lift_elem g = \s (Elem a) -> g s a
3245-
#endif
3367+
# endif
32463368
{-# INLINE lift_elem #-}
32473369
-- We have to specialize these functions by hand, unfortunately, because
32483370
-- GHC does not specialize until *all* instances are determined.
@@ -3281,9 +3403,6 @@ foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs'
32813403

32823404
foldMapWithIndexNodeN :: Monoid m => (Int -> Node a -> m) -> Int -> Node (Node a) -> m
32833405
foldMapWithIndexNodeN f i t = foldWithIndexNode (<>) f i t
3284-
3285-
#if __GLASGOW_HASKELL__
3286-
{-# INLINABLE foldMapWithIndex #-}
32873406
#endif
32883407

32893408
-- | 'traverseWithIndex' is a version of 'traverse' that also offers
@@ -5036,3 +5155,7 @@ fromList2 n = execState (replicateA n (State ht))
50365155
where
50375156
ht (x:xs) = (xs, x)
50385157
ht [] = error "fromList2: short list"
5158+
5159+
{-# NOINLINE bongo #-}
5160+
bongo :: Seq [a] -> [a]
5161+
bongo xs = GHC.Exts.inline foldMap id xs
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
{-# OPTIONS_GHC -ddump-prep #-}
2+
{-# LANGUAGE GADTs #-}
3+
{-# LANGUAGE KindSignatures #-}
4+
{-# LANGUAGE PatternSynonyms #-}
5+
{-# LANGUAGE RoleAnnotations #-}
6+
{-# LANGUAGE Trustworthy #-}
7+
{-# LANGUAGE TypeOperators #-}
8+
{-# LANGUAGE ViewPatterns #-}
9+
10+
-- | This module defines efficient representations of GADTs that are shaped
11+
-- like (strict) unary natural numbers. That is, each type looks, from the
12+
-- outside, something like this:
13+
--
14+
-- @
15+
-- data NatLike ... where
16+
-- ZeroLike :: NatLike ...
17+
-- SuccLike :: !(NatLike ...) -> NatLike ...
18+
-- @
19+
--
20+
-- but in fact it is represented by a single machine word. We put these in a
21+
-- separate module to confine the highly unsafe magic used in the
22+
-- implementation.
23+
--
24+
-- Caution: Unlike the GADTs they represent, the types in this module are
25+
-- bounded by @maxBound \@Word@, and attempting to take a successor of the
26+
-- maximum bound will throw an overflow error. That's okay for our purposes
27+
-- of implementing certain functions in "Data.Sequence.Internal"—the spine
28+
-- of a well-formed sequence can only reach a length of around the word
29+
-- size, not even close to @maxBound \@Word@.
30+
31+
module Data.Sequence.Internal.Depth
32+
( Depth_ (Bottom, Deeper)
33+
, Depth2_ (Bottom2, Deeper2)
34+
) where
35+
36+
import Data.Kind (Type)
37+
import Unsafe.Coerce (unsafeCoerce)
38+
39+
-- @Depth_@ is an optimized representation of the following GADT:
40+
--
41+
-- @
42+
-- data Depth_ node a t where
43+
-- Bottom :: Depth_ node a a
44+
-- Deeper :: !(Depth_ node a t) -> Depth_ node a (node t)
45+
-- @
46+
--
47+
-- "Data.Sequence.Internal" fills in the @node@ parameter with its @Node@
48+
-- constructor; we have to be more general in this module because we don't
49+
-- have access to that.
50+
--
51+
-- @Depth_@ is represented internally as a 'Word' for performance, and the
52+
-- 'Bottom' and 'Deeper' pattern synonyms implement the above GADT interface.
53+
-- The implementation is "safe"—in the very unlikely event of arithmetic
54+
-- overflow, an error will be thrown. This decision is subject to change;
55+
-- arithmetic overflow on 64-bit systems requires somewhat absurdly long
56+
-- computations on sequences constructed with extensive amounts of internal
57+
-- sharing (e.g., using the '*>' operator repeatedly).
58+
newtype Depth_ (node :: Type -> Type) (a :: Type) (t :: Type)
59+
= Depth_ Word
60+
type role Depth_ nominal nominal nominal
61+
62+
-- | The depth is 0.
63+
pattern Bottom :: () => t ~ a => Depth_ node a t
64+
pattern Bottom <- (checkBottom -> AtBottom)
65+
where
66+
Bottom = Depth_ 0
67+
68+
-- | The depth is non-zero.
69+
pattern Deeper :: () => t ~ node t' => Depth_ node a t' -> Depth_ node a t
70+
pattern Deeper d <- (checkBottom -> NotBottom d)
71+
where
72+
Deeper (Depth_ d)
73+
| d == maxBound = error "Depth overflow"
74+
| otherwise = Depth_ (d + 1)
75+
76+
{-# COMPLETE Bottom, Deeper #-}
77+
78+
data CheckedBottom node a t where
79+
AtBottom :: CheckedBottom node a a
80+
NotBottom :: !(Depth_ node a t) -> CheckedBottom node a (node t)
81+
82+
checkBottom :: Depth_ node a t -> CheckedBottom node a t
83+
checkBottom (Depth_ 0) = unsafeCoerce AtBottom
84+
checkBottom (Depth_ d) = unsafeCoerce (NotBottom (Depth_ (d - 1)))
85+
86+
87+
-- | A version of 'Depth_' for implementing traversals. Conceptually,
88+
--
89+
-- @
90+
-- data Depth2_ node a t b u where
91+
-- Bottom2 :: Depth_ node a a b b
92+
-- Deeper2 :: !(Depth_ node a t b u) -> Depth_ node a (node t) b (node u)
93+
-- @
94+
newtype Depth2_ (node :: Type -> Type) (a :: Type) (t :: Type) (b :: Type) (u :: Type)
95+
= Depth2_ Word
96+
type role Depth2_ nominal nominal nominal nominal nominal
97+
98+
-- | The depth is 0.
99+
pattern Bottom2 :: () => (t ~ a, u ~ b) => Depth2_ node a t b u
100+
pattern Bottom2 <- (checkBottom2 -> AtBottom2)
101+
where
102+
Bottom2 = Depth2_ 0
103+
104+
-- | The depth is non-zero.
105+
pattern Deeper2 :: () => (t ~ node t', u ~ node u') => Depth2_ node a t' b u' -> Depth2_ node a t b u
106+
pattern Deeper2 d <- (checkBottom2 -> NotBottom2 d)
107+
where
108+
Deeper2 (Depth2_ d)
109+
| d == maxBound = error "Depth2 overflow"
110+
| otherwise = Depth2_ (d + 1)
111+
112+
{-# COMPLETE Bottom2, Deeper2 #-}
113+
114+
data CheckedBottom2 node a t b u where
115+
AtBottom2 :: CheckedBottom2 node a a b b
116+
NotBottom2 :: !(Depth2_ node a t b u) -> CheckedBottom2 node a (node t) b (node u)
117+
118+
checkBottom2 :: Depth2_ node a t b u -> CheckedBottom2 node a t b u
119+
checkBottom2 (Depth2_ 0) = unsafeCoerce AtBottom2
120+
checkBottom2 (Depth2_ d) = unsafeCoerce (NotBottom2 (Depth2_ (d - 1)))

0 commit comments

Comments
 (0)