@@ -146,23 +146,26 @@ initCodegen config mp variableIDs =
146146 where
147147 (cs, rest) = partition (`Set.member` Set. fromList variableIDs) $ nodeIDs mp
148148 f (addressMap, curSizeReal, curSizeComplex) nID =
149- let (shape, et, node) = retrieveNode nID mp
150- in case et of
151- R -> (Map. insert nID (AddressReal curSizeReal) addressMap, curSizeReal + product shape, curSizeComplex)
152- C -> (Map. insert nID (AddressComplex curSizeComplex) addressMap, curSizeReal, curSizeComplex + product shape)
149+ let (shape, et, op) = retrieveNode nID mp
150+ in case (op, et) of
151+ (Coerce {}, _) -> (addressMap, curSizeReal, curSizeComplex)
152+ (_, R ) -> (Map. insert nID (AddressReal curSizeReal) addressMap, curSizeReal + product shape, curSizeComplex)
153+ (_, C ) -> (Map. insert nID (AddressComplex curSizeComplex) addressMap, curSizeReal, curSizeComplex + product shape)
153154 (memMap, totalSizeReal, totalSizeComplex) = foldl' f (Map. empty, 0 , 0 ) $ cs ++ rest
154155 addressMap nID
155156 | Just offset <- Map. lookup nID memMap = offset
156157 | otherwise = error " Node ID doesn't exist in address map"
157158 access :: NodeID -> Text -> Text
158- access nID offsetVal =
159- let offset
160- | offsetVal == " " = " "
161- | offsetVal == " 0" = " "
162- | otherwise = " + " <> offsetVal
163- in case addressMap nID of
164- AddressReal i -> " ptr[" <> tt i <> offset <> " ]"
165- AddressComplex i -> " ptr_c[" <> tt i <> offset <> " ]"
159+ access nID offsetVal
160+ | Coerce _ from <- retrieveOp nID mp = access from offsetVal
161+ | otherwise =
162+ let offset
163+ | offsetVal == " " = " "
164+ | offsetVal == " 0" = " "
165+ | otherwise = " + " <> offsetVal
166+ in case addressMap nID of
167+ AddressReal i -> " ptr[" <> tt i <> offset <> " ]"
168+ AddressComplex i -> " ptr_c[" <> tt i <> offset <> " ]"
166169
167170---------------------------------------------------------------------------------
168171evaluating :: CSimpleCodegen -> [NodeID ] -> Code
@@ -385,17 +388,17 @@ evaluating CSimpleCodegen {..} rootIDs =
385388 in Scoped [copyBase, injectSub]
386389 MatMul x y ->
387390 case (retrieveShape x cExpressionMap, retrieveShape y cExpressionMap) of
388- -- ([size1, size2], [_size2]) ->
389- -- for i size1 $
390- -- [ if et == R
391- -- then "double acc" := "0"
392- -- else "double complex acc" := "0",
393- -- for j size2 $
394- -- [ "int ij" := ("i * " <> tt size1 <> " + j"),
395- -- "acc" := ("acc + " <> (x !! "ij") <> " * " <> (y !! j))
396- -- ],
397- -- (n !! i) := "acc"
398- -- ]
391+ ([size1, size2], [_size2]) ->
392+ for i size1 $
393+ [ if et == R
394+ then " double acc" := " 0"
395+ else " double complex acc" := " 0" ,
396+ for j size2 $
397+ [ " int ij" := (" i * " <> tt size2 <> " + j" ),
398+ " acc" := (" acc + " <> (x !! " ij" ) <> " * " <> (y !! j))
399+ ],
400+ (n !! i) := " acc"
401+ ]
399402 ([size1, size2], [_size2, size3]) ->
400403 for i size1 $
401404 [ for j size3 $
@@ -412,7 +415,6 @@ evaluating CSimpleCodegen {..} rootIDs =
412415 ]
413416 ]
414417 Transpose x -> case retrieveShape x cExpressionMap of
415- -- [size] -> for i size [(n !! i) := (x !! i)]
416418 [size1, size2] ->
417419 for i size2 $
418420 [ for j size1 $
@@ -421,6 +423,7 @@ evaluating CSimpleCodegen {..} rootIDs =
421423 (n !! " ij" ) := (x !! " ji" )
422424 ]
423425 ]
426+ Coerce {} -> Empty
424427 node -> error $ " Not implemented " ++ show node
425428
426429--
0 commit comments