Skip to content

Commit 0788eea

Browse files
authored
Merge pull request #17 from McMasterU/differentiation_backward
Differentiation reverse method
2 parents e471d50 + 62c9be7 commit 0788eea

27 files changed

+636
-124
lines changed

HashedExpression.cabal

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ cabal-version: 1.12
44
--
55
-- see: https://github.com/sol/hpack
66
--
7-
-- hash: 53e3c8434f7f2a6206e7b49dcb5c186d6ea5ffb6931a3c138160d6bbae9c09e8
7+
-- hash: 158d5ed7ffd77d8520dc6eedabb8d47f1240293abc2485d2358bc8fa6a262c85
88

99
name: HashedExpression
1010
version: 0.1.0.0
@@ -39,11 +39,13 @@ library
3939
HashedExpression.Codegen
4040
HashedExpression.Codegen.CSIMD
4141
HashedExpression.Codegen.CSimple
42-
HashedExpression.Derivative
43-
HashedExpression.Derivative.Partial
42+
HashedExpression.Differentiation.Exterior
43+
HashedExpression.Differentiation.Exterior.Collect
44+
HashedExpression.Differentiation.Exterior.Derivative
45+
HashedExpression.Differentiation.Reverse
46+
HashedExpression.Differentiation.Reverse.State
4447
HashedExpression.Embed.FFTW
4548
HashedExpression.Internal
46-
HashedExpression.Internal.CollectDifferential
4749
HashedExpression.Internal.Collision
4850
HashedExpression.Internal.Expression
4951
HashedExpression.Internal.Hash
@@ -180,6 +182,7 @@ test-suite HashedExpression-test
180182
InterpSpec
181183
NormalizeSpec
182184
ProblemSpec
185+
ReverseDifferentiationSpec
183186
StructureSpec
184187
Var
185188
Paths_HashedExpression

TODO.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ stack haddock --haddock-arguments "--odir=docs/"
3131
### TODO add regression tests for examples
3232
### TODO Better interface for Transformation's (in Inner.hs) needed?
3333
- Make relationship between Transformation/Modification/Change clearer? Put in it's own module?
34-
- toRecursiveSimplification and toRecursiveCollecting should be reduced to one function?
3534
### TODO Make sure we don't introduce bugs doing CodeGen for FT
3635
### TODO Maybe use Numeric Prelude for better Num class and then better VectorSpace
3736
### TODO add cVariable1D name = variable1D (name ++ "Re") +: variable1D (name ++ "Im"), cVariable2D = ...

app/Main.hs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ import Data.Map (empty, fromList, union)
1111
import Data.Maybe (fromJust)
1212
import Data.STRef.Strict
1313
import qualified Data.Set as Set
14-
import HashedExpression.Derivative.Partial
1514
import Graphics.EasyPlot
1615
import HashedExpression
17-
import HashedExpression.Derivative
16+
import HashedExpression.Differentiation.Exterior.Derivative
1817
import HashedExpression.Interp
1918
import HashedExpression.Operation
2019
import qualified HashedExpression.Operation

src/HashedExpression.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ module HashedExpression
5757
)
5858
where
5959

60-
import HashedExpression.Derivative
61-
import HashedExpression.Internal.CollectDifferential
60+
import HashedExpression.Differentiation.Exterior.Collect
61+
import HashedExpression.Differentiation.Exterior.Derivative
6262
import HashedExpression.Internal.Expression
6363
import HashedExpression.Internal.Normalize
6464
import HashedExpression.Interp

src/HashedExpression/Codegen/CSimple.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ evaluating CSimpleCodegen {..} rootIDs =
219219
]
220220
RealPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `reAt` i};|]]
221221
ImagPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `imAt` i};|]]
222+
Conjugate arg ->
223+
for i (len n) $
224+
[ [I.i|#{n `reAt` i} = #{arg `reAt` i};|],
225+
[I.i|#{n `imAt` i} = -#{arg `imAt` i};|]
226+
]
222227
InnerProd arg1 arg2
223228
| et == R && null (shapeOf arg1) -> [[I.i|#{n !! nooffset} = #{arg1 !! nooffset} * #{arg2 !! nooffset};|]]
224229
| et == R ->

src/HashedExpression/Derivative/Partial.hs

Lines changed: 0 additions & 36 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
-- |
2+
-- Module : HashedExpression.Differentiation.Exterior.Derivative
3+
-- Copyright : (c) OCA 2020
4+
-- License : MIT (see the LICENSE file)
5+
-- Maintainer : anandc@mcmaster.ca
6+
-- Stability : provisional
7+
-- Portability : unportable
8+
module HashedExpression.Differentiation.Exterior where
9+
10+
import Data.Map.Strict (Map)
11+
import qualified Data.Map.Strict as Map
12+
import Data.Maybe (mapMaybe)
13+
import HashedExpression.Differentiation.Exterior.Collect
14+
import HashedExpression.Differentiation.Exterior.Derivative
15+
import HashedExpression.Internal
16+
import HashedExpression.Internal.Expression
17+
import HashedExpression.Internal.Node
18+
19+
partialDerivativesMapByExterior :: Expression Scalar R -> (ExpressionMap, Map String NodeID)
20+
partialDerivativesMapByExterior exp =
21+
let (mp, rootID) = unwrap . collectDifferentials . derivativeAllVars $ exp
22+
in (mp, partialDerivativesMap (mp, rootID))
23+
24+
-- | Return a map from variable name to the corresponding partial derivative node id
25+
-- Partial derivatives in Expression Scalar Covector should be collected before passing to this function
26+
partialDerivativesMap :: (ExpressionMap, NodeID) -> Map String NodeID
27+
partialDerivativesMap (dfMp, dfId) =
28+
case retrieveOp dfId dfMp of
29+
Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns
30+
_ -> Map.fromList $ mapMaybe getPartial [dfId]
31+
where
32+
getPartial :: NodeID -> Maybe (String, NodeID)
33+
getPartial nId
34+
| MulD partialId dId <- retrieveOp nId dfMp,
35+
DVar name <- retrieveOp dId dfMp =
36+
Just (name, partialId)
37+
| InnerProdD partialId dId <- retrieveOp nId dfMp,
38+
DVar name <- retrieveOp dId dfMp =
39+
Just (name, partialId)
40+
| otherwise = Nothing

src/HashedExpression/Internal/CollectDifferential.hs renamed to src/HashedExpression/Differentiation/Exterior/Collect.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- |
2-
-- Module : HashedExpression.Internal.CollectDifferential
2+
-- Module : HashedExpression.Differentiation.Exterior.Collect
33
-- Copyright : (c) OCA 2020
44
-- License : MIT (see the LICENSE file)
55
-- Maintainer : anandc@mcmaster.ca
@@ -8,7 +8,7 @@
88
--
99
-- This module exists solely to factor terms around their differentials. When properly factored, the term multiplying
1010
-- a differential (say dx) is it's corresponding parital derivative (i.e derivative w.r.t x)
11-
module HashedExpression.Internal.CollectDifferential
11+
module HashedExpression.Differentiation.Exterior.Collect
1212
( collectDifferentials,
1313
)
1414
where

src/HashedExpression/Derivative.hs renamed to src/HashedExpression/Differentiation/Exterior/Derivative.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{-# LANGUAGE ScopedTypeVariables #-}
22

33
-- |
4-
-- Module : HashedExpression.Derivative
4+
-- Module : HashedExpression.Differentiation.Exterior.Derivative
55
-- Copyright : (c) OCA 2020
66
-- License : MIT (see the LICENSE file)
77
-- Maintainer : anandc@mcmaster.ca
@@ -17,7 +17,7 @@
1717
-- Computing an exterior derivative on an expression @Expression d R@ will result in a @Expression d Covector@, i.e a 'Covector' field
1818
-- (also known as 1-form). This will contain 'dVar' terms representing where implicit differentiation has occurred. See 'CollectDifferential'
1919
-- to factor like terms for producing partial derivatives
20-
module HashedExpression.Derivative
20+
module HashedExpression.Differentiation.Exterior.Derivative
2121
( exteriorDerivative,
2222
derivativeAllVars,
2323
)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
-- |
2+
-- Module : HashedExpression.Differentiation.Exterior.Collect
3+
-- Copyright : (c) OCA 2020
4+
-- License : MIT (see the LICENSE file)
5+
-- Maintainer : anandc@mcmaster.ca
6+
-- Stability : provisional
7+
-- Portability : unportable
8+
--
9+
-- Compute differentiations using reverse accumulation method
10+
-- https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation
11+
module HashedExpression.Differentiation.Reverse where
12+
13+
import Control.Monad.State.Strict
14+
import qualified Data.IntMap.Strict as IM
15+
import Data.List (foldl')
16+
import Data.List.HT (removeEach)
17+
import Data.Map.Strict (Map)
18+
import qualified Data.Map.Strict as Map
19+
import HashedExpression.Differentiation.Reverse.State
20+
import HashedExpression.Internal
21+
import HashedExpression.Internal.Expression
22+
import HashedExpression.Internal.Hash
23+
import HashedExpression.Internal.Node
24+
import HashedExpression.Internal.OperationSpec
25+
import HashedExpression.Internal.Structure
26+
import Prelude hiding ((^))
27+
28+
-- |
29+
partialDerivativesMapByReverse ::
30+
Expression Scalar R ->
31+
(ExpressionMap, Map String NodeID)
32+
partialDerivativesMapByReverse (Expression rootID mp) =
33+
let reverseTopoOrder = reverse $ topologicalSort (mp, rootID)
34+
init = ComputeDState mp IM.empty Map.empty
35+
-- Chain rule
36+
go :: ComputeReverseM ()
37+
go = forM_ reverseTopoOrder $ \nID -> do
38+
--- NodeID of derivative w.r.t to current node: d(f) / d(nID)
39+
dN <-
40+
if nID == rootID
41+
then sNum 1
42+
else do
43+
dPartsFromParent <- IM.lookup nID <$> gets computedPartsByParents
44+
-- Sum all the derivative parts incurred by its parents
45+
case dPartsFromParent of
46+
Just [d] -> from d
47+
Just ds -> perform (Nary specSum) ds
48+
curMp <- gets contextMap
49+
let (shape, et, op) = retrieveNode nID curMp
50+
let one = introduceNode (shape, R, Const 1)
51+
let zero = introduceNode (shape, R, Const 0)
52+
case op of
53+
Var name -> modifyPartialDerivativeMap (Map.insert name dN)
54+
Const _ -> return ()
55+
Sum args -> do
56+
forM_ args $ \x -> do
57+
let dX = dN
58+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
59+
Mul args -> do
60+
forM_ (removeEach args) $ \(x, rest) -> do
61+
productRest <- perform (Nary specMul) rest
62+
if et == R
63+
then do
64+
dX <- from dN * from productRest
65+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
66+
else do
67+
dX <- from dN * conjugate (from productRest)
68+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
69+
Power alpha x -> case et of
70+
R -> do
71+
dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1)))
72+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
73+
C -> do
74+
dX <- sNum (fromIntegral alpha) *. (from dN * conjugate (from x ^ (alpha - 1)))
75+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
76+
Neg x -> do
77+
dX <- negate $ from dN
78+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
79+
Scale scalar scalee -> do
80+
case (retrieveElementType scalar curMp, retrieveElementType scalee curMp) of
81+
(R, R) -> do
82+
-- for scalar
83+
dScalar <- from dN <.> from scalee
84+
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
85+
-- for scalee
86+
dScalee <- from scalar *. from dN
87+
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
88+
(R, C) -> do
89+
-- for scalar
90+
dScalar <- xRe (from scalee) <.> xRe (from dN) + xIm (from scalee) <.> xIm (from dN)
91+
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
92+
-- for scalee
93+
dScalee <- from scalar *. from dN
94+
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
95+
(C, C) -> do
96+
-- for scalar
97+
dScalar <- from dN <.> from scalee
98+
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
99+
-- for scalee
100+
dScalee <- conjugate (from scalar) *. from dN
101+
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
102+
Div x y -> do
103+
dX <- from dN / from y
104+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
105+
dY <- from dN * from x * (from y ^ (-2))
106+
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
107+
Sqrt x -> do
108+
dX <- sNum 0.5 *. (from dN / sqrt (from x))
109+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
110+
Sin x -> do
111+
dX <- from dN * cos (from x)
112+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
113+
Cos x -> do
114+
dX <- from dN * (- sin (from x))
115+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
116+
Tan x -> do
117+
dX <- from dN * (cos (from x) ^ (-2))
118+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
119+
Exp x -> do
120+
dX <- from dN * exp (from x)
121+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
122+
Log x -> do
123+
dX <- from dN * (from x ^ (-1))
124+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
125+
Sinh x -> do
126+
dX <- from dN * cosh (from x)
127+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
128+
Cosh x -> do
129+
dX <- from dN * sinh (from x)
130+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
131+
Tanh x -> do
132+
dX <- from dN * (one - tanh (from x) ^ 2)
133+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
134+
Asin x -> do
135+
dX <- from dN * (one / sqrt (one - from x ^ 2))
136+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
137+
Acos x -> do
138+
dX <- from dN * (- one / sqrt (one - from x ^ 2))
139+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
140+
Atan x -> do
141+
dX <- from dN * (one / one + from x ^ 2)
142+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
143+
Asinh x -> do
144+
dX <- from dN * (one / sqrt (one + from x ^ 2))
145+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
146+
Acosh x -> do
147+
dX <- from dN * (one / sqrt (from x ^ 2 - one))
148+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
149+
Atanh x -> do
150+
dX <- from dN * (one / sqrt (one - from x ^ 2))
151+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
152+
RealImag re im -> do
153+
dRe <- xRe $ from dN
154+
modifyComputedPartsByParents (IM.insertWith (++) re [dRe])
155+
dIm <- xIm $ from dN
156+
modifyComputedPartsByParents (IM.insertWith (++) im [dIm])
157+
RealPart reIm -> do
158+
dReIm <- from dN +: zero
159+
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
160+
ImagPart reIm -> do
161+
dReIm <- zero +: from dN
162+
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
163+
InnerProd x y -> do
164+
case et of
165+
R -> do
166+
dX <- from dN *. from y
167+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
168+
dY <- from dN *. from x
169+
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
170+
C -> do
171+
dX <- from dN *. from y
172+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
173+
dY <- conjugate (from dN) *. from x
174+
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
175+
Piecewise marks condition branches -> do
176+
dCondition <- zero
177+
modifyComputedPartsByParents (IM.insertWith (++) condition [dCondition])
178+
let numBranches = length branches
179+
forM_ (zip branches [0 ..]) $ \(branch, idx) -> case et of
180+
R -> do
181+
associate <- piecewise marks (from condition) (replicate idx zero ++ [one] ++ replicate (numBranches - idx - 1) zero)
182+
dBranch <- from dN * from associate
183+
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
184+
C -> do
185+
let zeroC = zero +: zero
186+
let oneC = one +: zero
187+
associate <- piecewise marks (from condition) (replicate idx zeroC ++ [oneC] ++ replicate (numBranches - idx - 1) zeroC)
188+
dBranch <- from dN * conjugate (from associate)
189+
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
190+
Rotate amount x -> do
191+
dX <- perform (Unary (specRotate (map negate amount))) [dN]
192+
dX <- rotate (map negate amount) $ from dN
193+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
194+
ReFT x
195+
| retrieveElementType x curMp == R -> do
196+
dX <- reFT (from dN)
197+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
198+
| otherwise -> do
199+
dX <- reFT (from dN) +: (- imFT (from dN))
200+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
201+
ImFT x
202+
| retrieveElementType x curMp == R -> do
203+
dX <- imFT (from dN)
204+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
205+
| otherwise -> do
206+
dX <- imFT (from dN) +: (- reFT (from dN))
207+
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
208+
(_, res) = runState go init
209+
in (contextMap res, partialDerivativeMap res)
210+

0 commit comments

Comments
 (0)