1- {-# LANGUAGE NamedFieldPuns #-}
2- {-# LANGUAGE ScopedTypeVariables #-}
1+ {-# LANGUAGE DisambiguateRecordFields #-}
2+ {-# LANGUAGE NamedFieldPuns #-}
3+ {-# LANGUAGE ScopedTypeVariables #-}
4+ {-# LANGUAGE TupleSections #-}
35
46-- | The module should be imported qualified.
57--
@@ -16,19 +18,19 @@ module Ouroboros.Network.TxSubmission.Mempool.Simple
1618import Prelude hiding (read , seq )
1719
1820import Control.Concurrent.Class.MonadSTM.Strict
19-
21+ import Data.Bitraversable
22+ import Data.Either
2023import Data.Foldable (toList )
2124import Data.Foldable qualified as Foldable
22- import Data.Function (on )
23- import Data.List (find , nubBy )
25+ import Data.List (find )
2426import Data.Maybe (isJust )
2527import Data.Sequence (Seq )
2628import Data.Sequence qualified as Seq
2729import Data.Set (Set )
2830import Data.Set qualified as Set
2931
32+ import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (.. ))
3033import Ouroboros.Network.SizeInBytes
31- import Ouroboros.Network.TxSubmission.Inbound.V2.Types
3234import Ouroboros.Network.TxSubmission.Mempool.Reader
3335
3436
@@ -98,31 +100,65 @@ getReader getTxId getTxSize (Mempool mempool) =
98100 f idx tx = (getTxId tx, idx, getTxSize tx)
99101
100102
101- -- | A simple mempool writer.
103+ -- | A mempool writer which generalizes the tx submission mempool writer
104+ -- TODO: We could replace TxSubmissionMempoolWriter with this at some point
105+ --
106+ data MempoolWriter txid tx failure idx m =
107+ MempoolWriter {
108+
109+ -- | Compute the transaction id from a transaction.
110+ --
111+ -- This is used in the protocol handler to verify a full transaction
112+ -- matches a previously given transaction id.
113+ --
114+ txId :: tx -> txid ,
115+
116+ -- | Supply a batch of transactions to the mempool. They are either
117+ -- accepted or rejected individually, but in the order supplied.
118+ --
119+ -- The 'txid's of all transactions that were added successfully are
120+ -- returned.
121+ mempoolAddTxs :: [tx ] -> m [SubmitResult failure ]
122+ }
123+
124+
125+ -- | A mempool writer with validation harness
126+ -- PRECONDITION: no duplicates given to mempoolAddTxs
102127--
103- getWriter :: forall tx txid m .
128+ getWriter :: forall tx txid tx' failure m .
104129 ( MonadSTM m
105130 , Ord txid
106131 )
107132 => (tx -> txid )
108- -> (tx -> Bool )
109- -- ^ validate a tx
110- -> Mempool m tx
111- -> TxSubmissionMempoolWriter txid tx Int m
112- getWriter getTxId validateTx (Mempool mempool) =
113- TxSubmissionMempoolWriter {
114- txId = getTxId,
115-
116- mempoolAddTxs = \ txs -> do
117- atomically $ do
118- mempoolTxs <- readTVar mempool
119- let currentIds = Set. fromList (map getTxId (toList mempoolTxs))
120- validTxs = nubBy (on (==) getTxId)
121- $ filter
122- (\ tx -> validateTx tx
123- && getTxId tx `Set.notMember` currentIds)
124- txs
125- mempoolTxs' = Foldable. foldl' (Seq. |>) mempoolTxs validTxs
126- writeTVar mempool mempoolTxs'
127- return (map getTxId validTxs)
128- }
133+ -- ^ get txid of a tx
134+ -> ([tx ] -> m [tx' ])
135+ -- ^ monadic validation context, acquired once prior to all work
136+ -> (tx' -> Bool -> Either failure () )
137+ -- ^ validate a tx in an atomic block, any failing `tx` throws an exception.
138+ -> (failure -> STM m failure )
139+ -- ^ return `True` when a failure should throw an exception
140+ -> Mempool m txid tx
141+ -> MempoolWriter txid tx failure Int m
142+ getWriter getTxId withContext validateTx failureFilterFn (Mempool mempool) =
143+ MempoolWriter {
144+ txId = getTxId,
145+
146+ mempoolAddTxs = \ txs -> do
147+ txs' <- withContext txs
148+ atomically $ do
149+ MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
150+ result <- sequence
151+ [bimapM (fmap SubmitFail . failureFilterFn) (pure . const (txid, tx)) validated
152+ | (tx, tx') <- zip txs txs'
153+ , let txid = getTxId tx
154+ validated =
155+ validateTx tx' (txid `Set.member` mempoolSet)
156+ ]
157+ let (validIds, validTxs) = unzip . rights $ result
158+ mempoolTxs' = MempoolSeq {
159+ mempoolSet = Set. union mempoolSet (Set. fromList validIds),
160+ mempoolSeq = Foldable. foldl' (Seq. |>) mempoolSeq validTxs
161+ }
162+ writeTVar mempool mempoolTxs'
163+ return $ fromLeft SubmitSuccess <$> result
164+ }
0 commit comments