|
| 1 | +-- | |
| 2 | +-- Module : HashedExpression.Differentiation.Exterior.Collect |
| 3 | +-- Copyright : (c) OCA 2020 |
| 4 | +-- License : MIT (see the LICENSE file) |
| 5 | +-- Maintainer : anandc@mcmaster.ca |
| 6 | +-- Stability : provisional |
| 7 | +-- Portability : unportable |
| 8 | +-- |
| 9 | +-- Compute differentiations using reverse accumulation method |
| 10 | +-- https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation |
| 11 | +module HashedExpression.Differentiation.Reverse where |
| 12 | + |
| 13 | +import Control.Monad.State.Strict |
| 14 | +import qualified Data.IntMap.Strict as IM |
| 15 | +import Data.List (foldl') |
| 16 | +import Data.List.HT (removeEach) |
| 17 | +import Data.Map.Strict (Map) |
| 18 | +import qualified Data.Map.Strict as Map |
| 19 | +import HashedExpression.Differentiation.Reverse.State |
| 20 | +import HashedExpression.Internal |
| 21 | +import HashedExpression.Internal.Expression |
| 22 | +import HashedExpression.Internal.Hash |
| 23 | +import HashedExpression.Internal.Node |
| 24 | +import HashedExpression.Internal.OperationSpec |
| 25 | +import HashedExpression.Internal.Structure |
| 26 | +import Prelude hiding ((^)) |
| 27 | + |
| 28 | +-- | |
| 29 | +partialDerivativesMapByReverse :: |
| 30 | + Expression Scalar R -> |
| 31 | + (ExpressionMap, Map String NodeID) |
| 32 | +partialDerivativesMapByReverse (Expression rootID mp) = |
| 33 | + let reverseTopoOrder = reverse $ topologicalSort (mp, rootID) |
| 34 | + init = ComputeDState mp IM.empty Map.empty |
| 35 | + -- Chain rule |
| 36 | + go :: ComputeReverseM () |
| 37 | + go = forM_ reverseTopoOrder $ \nID -> do |
| 38 | + --- NodeID of derivative w.r.t to current node: d(f) / d(nID) |
| 39 | + dN <- |
| 40 | + if nID == rootID |
| 41 | + then sNum 1 |
| 42 | + else do |
| 43 | + dPartsFromParent <- IM.lookup nID <$> gets computedPartsByParents |
| 44 | + -- Sum all the derivative parts incurred by its parents |
| 45 | + case dPartsFromParent of |
| 46 | + Just [d] -> from d |
| 47 | + Just ds -> perform (Nary specSum) ds |
| 48 | + curMp <- gets contextMap |
| 49 | + let (shape, et, op) = retrieveNode nID curMp |
| 50 | + let one = introduceNode (shape, R, Const 1) |
| 51 | + let zero = introduceNode (shape, R, Const 0) |
| 52 | + case op of |
| 53 | + Var name -> modifyPartialDerivativeMap (Map.insert name dN) |
| 54 | + Const _ -> return () |
| 55 | + Sum args -> do |
| 56 | + forM_ args $ \x -> do |
| 57 | + let dX = dN |
| 58 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 59 | + Mul args -> do |
| 60 | + forM_ (removeEach args) $ \(x, rest) -> do |
| 61 | + productRest <- perform (Nary specMul) rest |
| 62 | + if et == R |
| 63 | + then do |
| 64 | + dX <- from dN * from productRest |
| 65 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 66 | + else do |
| 67 | + dX <- from dN * conjugate (from productRest) |
| 68 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 69 | + Power alpha x -> case et of |
| 70 | + R -> do |
| 71 | + dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1))) |
| 72 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 73 | + C -> do |
| 74 | + dX <- sNum (fromIntegral alpha) *. (from dN * conjugate (from x ^ (alpha - 1))) |
| 75 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 76 | + Neg x -> do |
| 77 | + dX <- negate $ from dN |
| 78 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 79 | + Scale scalar scalee -> do |
| 80 | + case (retrieveElementType scalar curMp, retrieveElementType scalee curMp) of |
| 81 | + (R, R) -> do |
| 82 | + -- for scalar |
| 83 | + dScalar <- from dN <.> from scalee |
| 84 | + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) |
| 85 | + -- for scalee |
| 86 | + dScalee <- from scalar *. from dN |
| 87 | + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) |
| 88 | + (R, C) -> do |
| 89 | + -- for scalar |
| 90 | + dScalar <- xRe (from scalee) <.> xRe (from dN) + xIm (from scalee) <.> xIm (from dN) |
| 91 | + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) |
| 92 | + -- for scalee |
| 93 | + dScalee <- from scalar *. from dN |
| 94 | + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) |
| 95 | + (C, C) -> do |
| 96 | + -- for scalar |
| 97 | + dScalar <- from dN <.> from scalee |
| 98 | + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) |
| 99 | + -- for scalee |
| 100 | + dScalee <- conjugate (from scalar) *. from dN |
| 101 | + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) |
| 102 | + Div x y -> do |
| 103 | + dX <- from dN / from y |
| 104 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 105 | + dY <- from dN * from x * (from y ^ (-2)) |
| 106 | + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) |
| 107 | + Sqrt x -> do |
| 108 | + dX <- sNum 0.5 *. (from dN / sqrt (from x)) |
| 109 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 110 | + Sin x -> do |
| 111 | + dX <- from dN * cos (from x) |
| 112 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 113 | + Cos x -> do |
| 114 | + dX <- from dN * (- sin (from x)) |
| 115 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 116 | + Tan x -> do |
| 117 | + dX <- from dN * (cos (from x) ^ (-2)) |
| 118 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 119 | + Exp x -> do |
| 120 | + dX <- from dN * exp (from x) |
| 121 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 122 | + Log x -> do |
| 123 | + dX <- from dN * (from x ^ (-1)) |
| 124 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 125 | + Sinh x -> do |
| 126 | + dX <- from dN * cosh (from x) |
| 127 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 128 | + Cosh x -> do |
| 129 | + dX <- from dN * sinh (from x) |
| 130 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 131 | + Tanh x -> do |
| 132 | + dX <- from dN * (one - tanh (from x) ^ 2) |
| 133 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 134 | + Asin x -> do |
| 135 | + dX <- from dN * (one / sqrt (one - from x ^ 2)) |
| 136 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 137 | + Acos x -> do |
| 138 | + dX <- from dN * (- one / sqrt (one - from x ^ 2)) |
| 139 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 140 | + Atan x -> do |
| 141 | + dX <- from dN * (one / one + from x ^ 2) |
| 142 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 143 | + Asinh x -> do |
| 144 | + dX <- from dN * (one / sqrt (one + from x ^ 2)) |
| 145 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 146 | + Acosh x -> do |
| 147 | + dX <- from dN * (one / sqrt (from x ^ 2 - one)) |
| 148 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 149 | + Atanh x -> do |
| 150 | + dX <- from dN * (one / sqrt (one - from x ^ 2)) |
| 151 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 152 | + RealImag re im -> do |
| 153 | + dRe <- xRe $ from dN |
| 154 | + modifyComputedPartsByParents (IM.insertWith (++) re [dRe]) |
| 155 | + dIm <- xIm $ from dN |
| 156 | + modifyComputedPartsByParents (IM.insertWith (++) im [dIm]) |
| 157 | + RealPart reIm -> do |
| 158 | + dReIm <- from dN +: zero |
| 159 | + modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) |
| 160 | + ImagPart reIm -> do |
| 161 | + dReIm <- zero +: from dN |
| 162 | + modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) |
| 163 | + InnerProd x y -> do |
| 164 | + case et of |
| 165 | + R -> do |
| 166 | + dX <- from dN *. from y |
| 167 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 168 | + dY <- from dN *. from x |
| 169 | + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) |
| 170 | + C -> do |
| 171 | + dX <- from dN *. from y |
| 172 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 173 | + dY <- conjugate (from dN) *. from x |
| 174 | + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) |
| 175 | + Piecewise marks condition branches -> do |
| 176 | + dCondition <- zero |
| 177 | + modifyComputedPartsByParents (IM.insertWith (++) condition [dCondition]) |
| 178 | + let numBranches = length branches |
| 179 | + forM_ (zip branches [0 ..]) $ \(branch, idx) -> case et of |
| 180 | + R -> do |
| 181 | + associate <- piecewise marks (from condition) (replicate idx zero ++ [one] ++ replicate (numBranches - idx - 1) zero) |
| 182 | + dBranch <- from dN * from associate |
| 183 | + modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch]) |
| 184 | + C -> do |
| 185 | + let zeroC = zero +: zero |
| 186 | + let oneC = one +: zero |
| 187 | + associate <- piecewise marks (from condition) (replicate idx zeroC ++ [oneC] ++ replicate (numBranches - idx - 1) zeroC) |
| 188 | + dBranch <- from dN * conjugate (from associate) |
| 189 | + modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch]) |
| 190 | + Rotate amount x -> do |
| 191 | + dX <- perform (Unary (specRotate (map negate amount))) [dN] |
| 192 | + dX <- rotate (map negate amount) $ from dN |
| 193 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 194 | + ReFT x |
| 195 | + | retrieveElementType x curMp == R -> do |
| 196 | + dX <- reFT (from dN) |
| 197 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 198 | + | otherwise -> do |
| 199 | + dX <- reFT (from dN) +: (- imFT (from dN)) |
| 200 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 201 | + ImFT x |
| 202 | + | retrieveElementType x curMp == R -> do |
| 203 | + dX <- imFT (from dN) |
| 204 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 205 | + | otherwise -> do |
| 206 | + dX <- imFT (from dN) +: (- reFT (from dN)) |
| 207 | + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) |
| 208 | + (_, res) = runState go init |
| 209 | + in (contextMap res, partialDerivativeMap res) |
| 210 | + |
0 commit comments