Skip to content

Commit 6f4cd79

Browse files
committed
[ resolver conflict ] merge remote
2 parents 0748823 + 759e468 commit 6f4cd79

File tree

20 files changed

+165
-129
lines changed

20 files changed

+165
-129
lines changed

Makefile

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ doc:
77
format:
88
find src -type f -name '*.hs' | xargs $(FORMAT)
99
find test -type f -name '*.hs' | xargs $(FORMAT)
10-
find symphony/Symphony -type f -name '*.hs' | xargs $(FORMAT)
11-
12-
13-
check:
14-
stack clean
15-
stack build --fast --ghc-options -Wall
1610

1711
clean:
1812
-git clean -f -x C/

README.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Type-safe modelling DSL, symbolic transformation, and code generation for solvin
44

55

66
## Features
7-
- A type-safe, correct by construction APIs to model optimization problems, empowered by Haskell's phantom-type and type-level programming.
7+
- A type-safe, correct-by-construction APIs to model optimization problems, empowered by Haskell's phantom-type and type-level programming.
88
- For example, adding 2 expressions with mismatched shape or element type (**R** or C) will result in type error will result in type error:
99
```haskell
1010
λ> let x = variable1D @10 "x"
@@ -106,19 +106,22 @@ Model is in [app/Examples/Ex2.hs](app/Examples/Ex2.hs), data is in [examples/ex2
106106

107107
```haskell
108108
sigmoid :: (Dimension d) => Expression d R -> Expression d R
109-
sigmoid x = 1.0 / (1.0 + exp (-x))
109+
sigmoid x = 1.0 / (1.0 + exp (- x))
110110

111111
ex2_logisticRegression :: OptimizationProblem
112112
ex2_logisticRegression =
113-
let x = param2D @118 @28 "x"
114-
y = param2D @118 @1 "y"
115-
theta = variable2D @28 @1 "theta"
113+
let -- variables
114+
theta = variable1D @28 "theta"
115+
-- parameters
116+
x = param2D @118 @28 "x"
117+
y = param1D @118 "y"
116118
hypothesis = sigmoid (x ** theta)
119+
-- regularization
117120
lambda = 1
118-
regTheta = project (range @1 @27, at @0) theta
121+
regTheta = project (range @1 @27) theta
119122
regularization = (lambda / 2) * (regTheta <.> regTheta)
120123
in OptimizationProblem
121-
{ objective = sumElements ((-y) * log hypothesis - (1 - y) * log (1 - hypothesis)) + regularization,
124+
{ objective = sumElements ((- y) * log hypothesis - (1 - y) * log (1 - hypothesis)) + regularization,
122125
constraints = [],
123126
values =
124127
[ x :-> VFile (TXT "x_expanded.txt"),

app/Examples/Brain.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ brain_reconstructFromMRI =
2626
re :-> VFile (HDF5 "kspace.h5" "re"),
2727
mask :-> VFile (HDF5 "mask.h5" "mask")
2828
],
29-
workingDir = "problems" </> "brain"
29+
workingDir = "examples" </> "brain"
3030
}
3131

3232
brain :: IO ()

app/Examples/Ex1.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ ex1_linearRegression =
1818
[ x :-> VFile (TXT "x.txt"),
1919
y :-> VFile (TXT "y.txt")
2020
],
21-
workingDir = "problems" </> "ex1"
21+
workingDir = "examples" </> "ex1"
2222
}
2323

2424
ex1 :: IO ()

app/Examples/Ex2.hs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,31 @@ module Examples.Ex2 where
22

33
import HashedExpression
44
import System.FilePath ((</>))
5-
import Prelude hiding ((^), (**))
5+
import Prelude hiding ((**), (^))
66

77
sigmoid :: (Dimension d) => Expression d R -> Expression d R
8-
sigmoid x = 1.0 / (1.0 + exp (-x))
8+
sigmoid x = 1.0 / (1.0 + exp (- x))
99

1010
ex2_logisticRegression :: OptimizationProblem
1111
ex2_logisticRegression =
12-
let x = param2D @118 @28 "x"
13-
y = param2D @118 @1 "y"
14-
theta = variable2D @28 @1 "theta"
12+
let -- variables
13+
theta = variable1D @28 "theta"
14+
-- parameters
15+
x = param2D @118 @28 "x"
16+
y = param1D @118 "y"
1517
hypothesis = sigmoid (x ** theta)
18+
-- regularization
1619
lambda = 1
17-
regTheta = project (range @1 @27, at @0) theta
20+
regTheta = project (range @1 @27) theta
1821
regularization = (lambda / 2) * (regTheta <.> regTheta)
1922
in OptimizationProblem
20-
{ objective = sumElements ((-y) * log hypothesis - (1 - y) * log (1 - hypothesis)) + regularization,
23+
{ objective = sumElements ((- y) * log hypothesis - (1 - y) * log (1 - hypothesis)) + regularization,
2124
constraints = [],
2225
values =
2326
[ x :-> VFile (TXT "x_expanded.txt"),
2427
y :-> VFile (TXT "y.txt")
2528
],
26-
workingDir = "problems" </> "ex2"
29+
workingDir = "examples" </> "ex2"
2730
}
2831

2932
ex2 :: IO ()

hie.yaml

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,7 @@ cradle:
33
- path: "./src"
44
component: "HashedExpression:lib"
55

6-
- path: "./app/Main.hs"
7-
component: "HashedExpression:exe:HashedExpression-exe"
8-
9-
- path: "./app/Examples.Brain.hs"
10-
component: "HashedExpression:exe:HashedExpression-exe"
11-
12-
- path: "./app/Examples.Ex1.hs"
13-
component: "HashedExpression:exe:HashedExpression-exe"
14-
15-
- path: "./app/Examples.Ex2.hs"
16-
component: "HashedExpression:exe:HashedExpression-exe"
17-
18-
- path: "./app/Paths_HashedExpression.hs"
6+
- path: "./app"
197
component: "HashedExpression:exe:HashedExpression-exe"
208

219
- path: "./test"

src/HashedExpression.hs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ module HashedExpression
2828
module HashedExpression.Value,
2929
module HashedExpression.Codegen,
3030
module HashedExpression.Codegen.CSimple,
31-
ValueAssignment(..),
32-
OptimizationProblem(..),
33-
proceed
31+
ValueAssignment (..),
32+
OptimizationProblem (..),
33+
proceed,
3434
)
3535
where
3636

@@ -59,13 +59,12 @@ mkValMap ss = Map.fromList $ mapMaybe f ss
5959
| (_, _, Param name) <- retrieveNode nID mp = Just (name, val)
6060
| otherwise = Nothing
6161

62-
data OptimizationProblem =
63-
OptimizationProblem
64-
{ objective :: Expression Scalar R,
65-
constraints :: [ConstraintStatement],
66-
values :: [ValueAssignment],
67-
workingDir :: String
68-
}
62+
data OptimizationProblem = OptimizationProblem
63+
{ objective :: Expression Scalar R,
64+
constraints :: [ConstraintStatement],
65+
values :: [ValueAssignment],
66+
workingDir :: String
67+
}
6968

7069
proceed :: Codegen codegen => OptimizationProblem -> codegen -> IO ()
7170
proceed OptimizationProblem {..} codegen = do

src/HashedExpression/Codegen/CSimple.hs

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,26 @@ initCodegen config mp variableIDs =
146146
where
147147
(cs, rest) = partition (`Set.member` Set.fromList variableIDs) $ nodeIDs mp
148148
f (addressMap, curSizeReal, curSizeComplex) nID =
149-
let (shape, et, node) = retrieveNode nID mp
150-
in case et of
151-
R -> (Map.insert nID (AddressReal curSizeReal) addressMap, curSizeReal + product shape, curSizeComplex)
152-
C -> (Map.insert nID (AddressComplex curSizeComplex) addressMap, curSizeReal, curSizeComplex + product shape)
149+
let (shape, et, op) = retrieveNode nID mp
150+
in case (op, et) of
151+
(Coerce {}, _) -> (addressMap, curSizeReal, curSizeComplex)
152+
(_, R) -> (Map.insert nID (AddressReal curSizeReal) addressMap, curSizeReal + product shape, curSizeComplex)
153+
(_, C) -> (Map.insert nID (AddressComplex curSizeComplex) addressMap, curSizeReal, curSizeComplex + product shape)
153154
(memMap, totalSizeReal, totalSizeComplex) = foldl' f (Map.empty, 0, 0) $ cs ++ rest
154155
addressMap nID
155156
| Just offset <- Map.lookup nID memMap = offset
156157
| otherwise = error "Node ID doesn't exist in address map"
157158
access :: NodeID -> Text -> Text
158-
access nID offsetVal =
159-
let offset
160-
| offsetVal == "" = ""
161-
| offsetVal == "0" = ""
162-
| otherwise = " + " <> offsetVal
163-
in case addressMap nID of
164-
AddressReal i -> "ptr[" <> tt i <> offset <> "]"
165-
AddressComplex i -> "ptr_c[" <> tt i <> offset <> "]"
159+
access nID offsetVal
160+
| Coerce _ from <- retrieveOp nID mp = access from offsetVal
161+
| otherwise =
162+
let offset
163+
| offsetVal == "" = ""
164+
| offsetVal == "0" = ""
165+
| otherwise = " + " <> offsetVal
166+
in case addressMap nID of
167+
AddressReal i -> "ptr[" <> tt i <> offset <> "]"
168+
AddressComplex i -> "ptr_c[" <> tt i <> offset <> "]"
166169

167170
---------------------------------------------------------------------------------
168171
evaluating :: CSimpleCodegen -> [NodeID] -> Code
@@ -385,17 +388,17 @@ evaluating CSimpleCodegen {..} rootIDs =
385388
in Scoped [copyBase, injectSub]
386389
MatMul x y ->
387390
case (retrieveShape x cExpressionMap, retrieveShape y cExpressionMap) of
388-
-- ([size1, size2], [_size2]) ->
389-
-- for i size1 $
390-
-- [ if et == R
391-
-- then "double acc" := "0"
392-
-- else "double complex acc" := "0",
393-
-- for j size2 $
394-
-- [ "int ij" := ("i * " <> tt size1 <> " + j"),
395-
-- "acc" := ("acc + " <> (x !! "ij") <> " * " <> (y !! j))
396-
-- ],
397-
-- (n !! i) := "acc"
398-
-- ]
391+
([size1, size2], [_size2]) ->
392+
for i size1 $
393+
[ if et == R
394+
then "double acc" := "0"
395+
else "double complex acc" := "0",
396+
for j size2 $
397+
[ "int ij" := ("i * " <> tt size2 <> " + j"),
398+
"acc" := ("acc + " <> (x !! "ij") <> " * " <> (y !! j))
399+
],
400+
(n !! i) := "acc"
401+
]
399402
([size1, size2], [_size2, size3]) ->
400403
for i size1 $
401404
[ for j size3 $
@@ -412,7 +415,6 @@ evaluating CSimpleCodegen {..} rootIDs =
412415
]
413416
]
414417
Transpose x -> case retrieveShape x cExpressionMap of
415-
-- [size] -> for i size [(n !! i) := (x !! i)]
416418
[size1, size2] ->
417419
for i size2 $
418420
[ for j size1 $
@@ -421,6 +423,7 @@ evaluating CSimpleCodegen {..} rootIDs =
421423
(n !! "ij") := (x !! "ji")
422424
]
423425
]
426+
Coerce {} -> Empty
424427
node -> error $ "Not implemented " ++ show node
425428

426429
--

src/HashedExpression/Differentiation/Reverse.hs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import HashedExpression.Differentiation.Reverse.State
1919
import HashedExpression.Internal
2020
import HashedExpression.Internal.Context
2121
import HashedExpression.Internal.Expression
22-
import HashedExpression.Internal.Hash
2322
import HashedExpression.Internal.Node
2423
import HashedExpression.Internal.OperationSpec
2524
import Prelude hiding ((**), (^))
@@ -223,12 +222,20 @@ partialDerivativesMap (Expression rootID mp) =
223222
C -> do
224223
dY <- inject dss (zeroX +: zeroX) (from dN)
225224
addDerivative y dY
226-
MatMul x y -> do
227-
-- mn np mp
228-
dX <- from dN ** transpose (from y)
229-
addDerivative x dX
230-
dY <- transpose (from x) ** from dN
231-
addDerivative y dY
225+
MatMul x y ->
226+
case (retrieveShape x curMp, retrieveShape y curMp) of
227+
([m, n], [_n, p]) -> do
228+
-- mn np mp
229+
dX <- from dN ** transpose (from y)
230+
addDerivative x dX
231+
dY <- transpose (from x) ** from dN
232+
addDerivative y dY
233+
([m, n], [_n]) -> do
234+
-- mn n m
235+
dX <- (coerceTo [m, 1] $ from dN) ** (coerceTo [1, n] $ from y)
236+
addDerivative x dX
237+
dY <- transpose (from x) ** from dN
238+
addDerivative y dY
232239
Transpose x -> do
233240
dX <- transpose $ from dN
234241
addDerivative x dX

src/HashedExpression/Differentiation/Reverse/State.hs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@ import Data.List.HT (removeEach)
1616
import Data.Map.Strict (Map)
1717
import qualified Data.Map.Strict as Map
1818
import GHC.Stack (HasCallStack)
19-
import HashedExpression.Internal
2019
import HashedExpression.Internal.Base
2120
import HashedExpression.Internal.Context
2221
import HashedExpression.Internal.Expression
2322
import HashedExpression.Internal.Hash
24-
import HashedExpression.Internal.Node
25-
import HashedExpression.Internal.OperationSpec
2623
import Prelude hiding ((^))
2724

2825
data ComputeDState = ComputeDState

0 commit comments

Comments
 (0)