Skip to content

Commit 31ccf5a

Browse files
authored
Merge pull request #24 from McMasterU/project_inject
Project inject
2 parents 0658aa0 + a4c22f9 commit 31ccf5a

File tree

20 files changed

+810
-232
lines changed

20 files changed

+810
-232
lines changed

app/Main.hs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
module Main where
55

6-
import Data.Array
6+
import qualified Data.Array as Array
77
import Data.Complex
88
import qualified Data.IntMap.Strict as IM
99
import Data.List (intercalate)
@@ -13,7 +13,7 @@ import Data.Maybe (fromJust)
1313
import Data.STRef.Strict
1414
import qualified Data.Set as Set
1515
import Graphics.EasyPlot
16-
import HashedExpression
16+
import HashedExpression.Internal.Expression
1717
import HashedExpression.Codegen
1818
import HashedExpression.Codegen.CSimple
1919
import HashedExpression.Differentiation.Reverse
@@ -25,12 +25,15 @@ import HashedExpression.Prettify
2525
import Data.String.Interpolate
2626
import Prelude hiding ((^))
2727
import Control.Monad (forM_)
28+
import Data.Data
2829

2930
main :: IO ()
3031
main = do
3132
let x = variable2D @10 @10 "x"
3233
let y = variable2D @10 @10 "y"
33-
let f = norm2square (ift (ft (x +: 0)) - 5)
34+
let kaka = variable1D @20 "x"
35+
let k = project (at @2) kaka
36+
let f = norm2square (ift (ft (x +: 0)) - 5) + k
3437
case constructProblem f (Constraint []) of
3538
ProblemValid problem ->
3639
case generateProblemCode (CSimpleConfig { output = OutputText }) problem Map.empty of

src/HashedExpression.hs

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,43 +19,21 @@
1919
-- the above code creates a simple HashedExpression using the
2020
-- 'variable' constructor method and taking advantage of the 'Num' class instance
2121
module HashedExpression
22-
( -- * Expression Constructors
23-
Expression,
24-
R,
25-
C,
26-
Covector,
27-
Scalar,
28-
PowerOp (..),
29-
PiecewiseOp (..),
30-
VectorSpaceOp (..),
31-
FTOp (..),
32-
NodeID,
33-
ComplexRealOp (..),
34-
RotateOp (..),
35-
InnerProductSpaceOp (..),
36-
37-
-- * Combinators
38-
constant,
39-
constant1D,
40-
constant2D,
41-
constant3D,
42-
variable,
43-
variable1D,
44-
variable2D,
45-
variable3D,
46-
param,
47-
param1D,
48-
param2D,
49-
param3D,
50-
51-
-- * Evaluation
52-
Evaluable (..),
53-
prettify,
22+
( module HashedExpression.Internal.Expression,
23+
module HashedExpression.Operation,
24+
module HashedExpression.Internal.Simplify,
25+
module HashedExpression.Prettify,
26+
module HashedExpression.Interp,
27+
module HashedExpression.Problem,
28+
module HashedExpression.Value,
5429
)
5530
where
5631

5732
import HashedExpression.Internal.Expression
33+
import HashedExpression.Internal.Simplify
5834
import HashedExpression.Interp
5935
import HashedExpression.Operation
6036
import HashedExpression.Prettify
37+
import HashedExpression.Problem
38+
import HashedExpression.Value
6139
import Prelude hiding ((^))

src/HashedExpression/Codegen/CSimple.hs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ for iter bound codes =
8484
]
8585
++ scoped codes
8686

87+
forRange :: Text -> (Int, Int, Int) -> Code -> Code
88+
forRange iter (start, end, step) codes =
89+
scoped $
90+
[ [i|int #{iter};|],
91+
[i|for (#{iter} = #{start}; #{iter} <= #{end}; #{iter} += #{step})|]
92+
]
93+
++ scoped codes
94+
8795
if_ :: Text -> Code -> Code
8896
if_ condition codes = [[i|if (#{condition})|]] ++ scoped codes
8997

@@ -337,6 +345,122 @@ evaluating CSimpleCodegen {..} rootIDs =
337345
case shape of
338346
[size] -> [[I.i|dft_1d(#{size}, #{addressOf arg}, #{addressOf n}, FFTW_BACKWARD);|]]
339347
[size1, size2] -> [[I.i|dft_2d(#{size1}, #{size2}, #{addressOf arg}, #{addressOf n}, FFTW_BACKWARD);|]]
348+
Project dss arg ->
349+
case (dss, retrieveShape arg cExpressionMap) of
350+
([ds], [size]) ->
351+
let toIndex i = [I.i|#{i} % #{size}|]
352+
in scoped $
353+
"int nxt = 0;" :
354+
( forRange i (toRange ds size) $
355+
if et == R
356+
then
357+
[ [I.i|#{n !! "nxt"} = #{arg !! (toIndex i)};|],
358+
"nxt++;"
359+
]
360+
else
361+
[ [I.i|#{n `reAt` "nxt"} = #{arg `reAt` (toIndex i)};|],
362+
[I.i|#{n `imAt` "nxt"} = #{arg `imAt` (toIndex i)};|],
363+
"nxt++;"
364+
]
365+
)
366+
([ds1, ds2], [size1, size2]) ->
367+
let toIndex i j = [I.i|(#{i} % #{size1}) * #{size2} + (#{j} % #{size2})|]
368+
in scoped $
369+
"int nxt = 0;" :
370+
( forRange i (toRange ds1 size1) $
371+
forRange j (toRange ds2 size2) $
372+
if et == R
373+
then
374+
[ [I.i|#{n !! "nxt"} = #{arg !! (toIndex i j)};|],
375+
"nxt++;"
376+
]
377+
else
378+
[ [I.i|#{n `reAt` "nxt"} = #{arg `reAt` (toIndex i j)};|],
379+
[I.i|#{n `imAt` "nxt"} = #{arg `imAt` (toIndex i j)};|],
380+
"nxt++;"
381+
]
382+
)
383+
([ds1, ds2, ds3], [size1, size2, size3]) ->
384+
let toIndex i j k = [I.i|(#{i} % #{size1}) * #{size2} * #{size3} + (#{j} % #{size2}) * #{size3} + (#{k} % #{size3})|]
385+
in scoped $
386+
"int nxt = 0;" :
387+
( forRange i (toRange ds1 size1) $
388+
forRange j (toRange ds2 size2) $
389+
forRange k (toRange ds3 size3) $
390+
if et == R
391+
then
392+
[ [I.i|#{n !! "nxt"} = #{arg !! (toIndex i j k)};|],
393+
"nxt++;"
394+
]
395+
else
396+
[ [I.i|#{n `reAt` "nxt"} = #{arg `reAt` (toIndex i j k)};|],
397+
[I.i|#{n `imAt` "nxt"} = #{arg `imAt` (toIndex i j k)};|],
398+
"nxt++;"
399+
]
400+
)
401+
Inject dss sub base ->
402+
let copyBase =
403+
if et == R
404+
then for i (len n) [[I.i|#{n !! i} = #{base !! i};|]]
405+
else
406+
for i (len n) $
407+
[ [I.i|#{n `reAt` i} = #{base `reAt` i};|],
408+
[I.i|#{n `imAt` i} = #{base `imAt` i};|]
409+
]
410+
injectSub =
411+
case (dss, retrieveShape n cExpressionMap) of
412+
([ds], [size]) ->
413+
let toIndex i = [I.i|#{i} % #{size}|]
414+
in scoped $
415+
"int nxt = 0;" :
416+
( forRange i (toRange ds size) $
417+
if et == R
418+
then
419+
[ [I.i|#{n !! (toIndex i)} = #{sub !! "nxt"};|],
420+
"nxt++;"
421+
]
422+
else
423+
[ [I.i|#{n `reAt` (toIndex i)} = #{sub `reAt` "nxt"};|],
424+
[I.i|#{n `imAt` (toIndex i)} = #{sub `imAt` "nxt"};|],
425+
"nxt++;"
426+
]
427+
)
428+
([ds1, ds2], [size1, size2]) ->
429+
let toIndex i j = [I.i|(#{i} % #{size1}) * #{size2} + (#{j} % #{size2})|]
430+
in scoped $
431+
"int nxt = 0;" :
432+
( forRange i (toRange ds1 size1) $
433+
forRange j (toRange ds2 size2) $
434+
if et == R
435+
then
436+
[ [I.i|#{n !! (toIndex i j)} = #{sub !! "nxt"};|],
437+
"nxt++;"
438+
]
439+
else
440+
[ [I.i|#{n `reAt` (toIndex i j)} = #{sub `reAt` "nxt"};|],
441+
[I.i|#{n `imAt` (toIndex i j)} = #{sub `imAt` "nxt"};|],
442+
"nxt++;"
443+
]
444+
)
445+
([ds1, ds2, ds3], [size1, size2, size3]) ->
446+
let toIndex i j k = [I.i|(#{i} % #{size1}) * #{size2} * #{size3} + (#{j} % #{size2}) * #{size3} + (#{k} % #{size3})|]
447+
in scoped $
448+
"int nxt = 0;" :
449+
( forRange i (toRange ds1 size1) $
450+
forRange j (toRange ds2 size2) $
451+
forRange k (toRange ds3 size3) $
452+
if et == R
453+
then
454+
[ [I.i|#{n !! (toIndex i j k)} = #{sub !! "nxt"};|],
455+
"nxt++;"
456+
]
457+
else
458+
[ [I.i|#{n `reAt` (toIndex i j k)} = #{sub `reAt` "nxt"};|],
459+
[I.i|#{n `imAt` (toIndex i j k)} = #{sub `imAt` "nxt"};|],
460+
"nxt++;"
461+
]
462+
)
463+
in scoped $ copyBase ++ injectSub
340464
node -> error $ "Not implemented " ++ show node
341465

342466
-------------------------------------------------------------------------------

src/HashedExpression/Differentiation/Reverse.hs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ partialDerivativesMap (Expression rootID mp) =
191191
dBranch <- from dN * conjugate (from associate)
192192
addDerivative branch dBranch
193193
Rotate amount x -> do
194-
dX <- perform (Unary (specRotate (map negate amount))) [dN]
195194
dX <- rotate (map negate amount) $ from dN
196195
addDerivative x dX
197196
FT x -> do
@@ -202,5 +201,25 @@ partialDerivativesMap (Expression rootID mp) =
202201
let sz = fromIntegral $ product shape
203202
dX <- sNum (1.0 / sz) *. ft (from dN)
204203
addDerivative x dX
204+
Project dss x -> do
205+
let zeroX = introduceNode (retrieveShape x curMp, R, Const 0)
206+
case et of
207+
R -> do
208+
dX <- inject dss (from dN) zeroX
209+
addDerivative x dX
210+
C -> do
211+
dX <- inject dss (from dN) (zeroX +: zeroX)
212+
addDerivative x dX
213+
Inject dss x y -> do
214+
let zeroX = introduceNode (retrieveShape x curMp, R, Const 0)
215+
dX <- project dss (from dN)
216+
addDerivative x dX
217+
case et of
218+
R -> do
219+
dY <- inject dss zeroX (from dN)
220+
addDerivative y dY
221+
C -> do
222+
dY <- inject dss (zeroX +: zeroX) (from dN)
223+
addDerivative y dY
205224
(_, res) = runState go init
206225
in (contextMap res, partialDerivativeMap res)

src/HashedExpression/Internal/Context.hs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ instance (MonadExpression m) => FTOp (m NodeID) (m NodeID) where
151151
x <- operand
152152
perform (Unary specIFT) [x]
153153

154+
instance (MonadExpression m) => ProjectInjectOp [DimSelector] (m NodeID) (m NodeID) where
155+
project ss operand = do
156+
x <- operand
157+
perform (Unary (specProject ss)) [x]
158+
inject ss sub base = do
159+
x <- sub
160+
y <- base
161+
perform (Binary (specInject ss)) [x, y]
162+
154163
-------------------------------------------------------------------------------
155164

156165
instance (MonadExpression m) => MulCovectorOp (m NodeID) (m NodeID) (m NodeID) where

src/HashedExpression/Internal/Expression.hs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ module HashedExpression.Internal.Expression
2121
Op (..),
2222
Node,
2323
NodeID,
24+
DimSelector (..),
25+
ProjectInjectOp (..),
2426
ExpressionMap,
2527
Expression (..),
2628
Arg,
@@ -182,6 +184,10 @@ data Op
182184
Rotate RotateAmount Arg
183185
| FT Arg
184186
| IFT Arg
187+
| -- | Projection
188+
Project [DimSelector] Arg
189+
| -- | Injection
190+
Inject [DimSelector] SubArg BaseArg -- inject Arg into BaseArg
185191
| -- | differentiable operators (only for exterior method)
186192
DVar String
187193
| DZero
@@ -205,8 +211,25 @@ type ConditionArg = NodeID
205211
-- of a 'ConditionArg' expression
206212
type BranchArg = NodeID
207213

214+
-- |
208215
type CovectorArg = NodeID
209216

217+
type SubArg = NodeID
218+
219+
type BaseArg = NodeID
220+
221+
-- |
222+
type Position = [Int]
223+
224+
-- | DimSelector for projection
225+
data DimSelector
226+
= At Int -- Will collapse the corresponding dimension
227+
| Range -- (inclusion)
228+
Int -- start
229+
Int -- end
230+
Int -- step
231+
deriving (Show, Eq, Ord)
232+
210233
-- --------------------------------------------------------------------------------------------------------------------
211234

212235
-- * Expression Element Types
@@ -387,6 +410,12 @@ class FTOp a b | a -> b, b -> a where
387410
ft :: a -> b
388411
ift :: b -> a
389412

413+
-- |
414+
class ProjectInjectOp s a b | s a -> b where
415+
project :: s -> a -> b
416+
inject :: s -> b -> a -> a
417+
418+
-------------------------------------------------------------------------------
390419
class MulCovectorOp a b c | a b -> c, c -> a, c -> b where
391420
(|*|) :: a -> b -> c
392421

src/HashedExpression/Internal/Hash.hs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ offsetHash offset hash =
5757
separator :: String
5858
separator = "a"
5959

60+
toStringHash :: DimSelector -> String
61+
toStringHash (At i) = "at" ++ show i
62+
toStringHash (Range b e st) = "range" ++ show b ++ separator ++ show e ++ separator ++ show st
63+
6064
-- | Compute a hash value for a given 'Node' and number of rehash
6165
hash :: Node -> Int -> Int
6266
hash (shape, et, node) rehashNum =
@@ -105,6 +109,8 @@ hash (shape, et, node) rehashNum =
105109
Rotate amount arg -> offsetHash 29 . hashString' $ (intercalate separator . map show $ amount) ++ separator ++ show arg
106110
FT arg -> offsetHash 30 . hashString' $ show arg
107111
IFT arg -> offsetHash 31 . hashString' $ show arg
112+
Project ss arg -> offsetHash 32 . hashString' $ (intercalate separator . map toStringHash $ ss) ++ separator ++ show arg
113+
Inject ss sub base -> offsetHash 33 . hashString' $ (intercalate separator . map toStringHash $ ss) ++ separator ++ show sub ++ show base
108114
-------------------------------------------------------------------------------
109115
Conjugate arg -> offsetHash 37 . hashString' $ show arg
110116
-- Mark: Covector

src/HashedExpression/Internal/Node.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ nodeTypeWeight node =
7777
Rotate {} -> 29
7878
FT {} -> 30
7979
IFT {} -> 31
80+
Project {} -> 32
81+
Inject {} -> 33
82+
-------------------------------------------------
8083
Scale {} -> 36 -- Right after RealImag
8184
RealImag {} -> 37 -- At the end right after sum
8285
Sum {} -> 38 -- Sum at the end
83-
------------------------
86+
-------------------------------------------------
8487
DVar {} -> 101
8588
DZero {} -> 102
8689
MulD {} -> 103
@@ -130,6 +133,8 @@ opArgs node =
130133
Rotate _ arg -> [arg]
131134
FT arg -> [arg]
132135
IFT arg -> [arg]
136+
Project ss arg -> [arg]
137+
Inject ss sub base -> [sub, base]
133138
DZero -> []
134139
MulD arg1 arg2 -> [arg1, arg2]
135140
ScaleD arg1 arg2 -> [arg1, arg2]
@@ -173,6 +178,8 @@ mapOp f op =
173178
Rotate am arg -> Rotate am (f arg)
174179
FT arg -> FT (f arg)
175180
IFT arg -> IFT (f arg)
181+
Project s arg -> Project s (f arg)
182+
Inject s sub base -> Inject s (f sub) (f base)
176183
DZero -> DZero
177184
MulD arg1 arg2 -> MulD (f arg1) (f arg2)
178185
ScaleD arg1 arg2 -> ScaleD (f arg1) (f arg2)

0 commit comments

Comments
 (0)