@@ -618,7 +618,7 @@ func createChannelEdge(node1, node2 *models.LightningNode) (
618
618
chanID := uint64 (prand .Int63 ())
619
619
outpoint := wire.OutPoint {
620
620
Hash : rev ,
621
- Index : 9 ,
621
+ Index : prand . Uint32 () ,
622
622
}
623
623
624
624
// Add the new edge to the database, this should proceed without any
@@ -991,6 +991,97 @@ func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy {
991
991
}
992
992
}
993
993
994
+ // TestForEachSourceNodeChannel tests that the ForEachSourceNodeChannel
995
+ // correctly iterates through the channels of the set source node.
996
+ func TestForEachSourceNodeChannel (t * testing.T ) {
997
+ t .Parallel ()
998
+
999
+ graph , err := MakeTestGraph (t )
1000
+ require .NoError (t , err , "unable to make test database" )
1001
+
1002
+ // Create a source node (A) and set it as such in the DB.
1003
+ nodeA := createTestVertex (t )
1004
+ require .NoError (t , graph .SetSourceNode (nodeA ))
1005
+
1006
+ // Now, create a few more nodes (B, C, D) along with some channels
1007
+ // between them. We'll create the following graph:
1008
+ //
1009
+ // A -- B -- D
1010
+ // |
1011
+ // C
1012
+ //
1013
+ // The graph includes a channel (B-D) that does not belong to the source
1014
+ // node along with 2 channels (A-B and A-C) that do belong to the source
1015
+ // node. For the A-B channel, we will let the source node set an
1016
+ // outgoing policy but for the A-C channel, we will set only an incoming
1017
+ // policy.
1018
+
1019
+ nodeB := createTestVertex (t )
1020
+ nodeC := createTestVertex (t )
1021
+ nodeD := createTestVertex (t )
1022
+
1023
+ abEdge , abPolicy1 , abPolicy2 := createChannelEdge (nodeA , nodeB )
1024
+ require .NoError (t , graph .AddChannelEdge (abEdge ))
1025
+ acEdge , acPolicy1 , acPolicy2 := createChannelEdge (nodeA , nodeC )
1026
+ require .NoError (t , graph .AddChannelEdge (acEdge ))
1027
+ bdEdge , _ , _ := createChannelEdge (nodeB , nodeD )
1028
+ require .NoError (t , graph .AddChannelEdge (bdEdge ))
1029
+
1030
+ // Figure out which of the policies returned above are node A's so that
1031
+ // we know which to persist.
1032
+ //
1033
+ // First, set the outgoing policy for the A-B channel.
1034
+ abPolicyAOutgoing := abPolicy1
1035
+ if ! bytes .Equal (abPolicy1 .ToNode [:], nodeB .PubKeyBytes [:]) {
1036
+ abPolicyAOutgoing = abPolicy2
1037
+ }
1038
+ require .NoError (t , graph .UpdateEdgePolicy (abPolicyAOutgoing ))
1039
+
1040
+ // Now, set the incoming policy for the A-C channel.
1041
+ acPolicyAIncoming := acPolicy1
1042
+ if ! bytes .Equal (acPolicy1 .ToNode [:], nodeA .PubKeyBytes [:]) {
1043
+ acPolicyAIncoming = acPolicy2
1044
+ }
1045
+ require .NoError (t , graph .UpdateEdgePolicy (acPolicyAIncoming ))
1046
+
1047
+ type sourceNodeChan struct {
1048
+ otherNode route.Vertex
1049
+ havePolicy bool
1050
+ }
1051
+
1052
+ // Put together our expected source node channels.
1053
+ expectedSrcChans := map [wire.OutPoint ]* sourceNodeChan {
1054
+ abEdge .ChannelPoint : {
1055
+ otherNode : nodeB .PubKeyBytes ,
1056
+ havePolicy : true ,
1057
+ },
1058
+ acEdge .ChannelPoint : {
1059
+ otherNode : nodeC .PubKeyBytes ,
1060
+ havePolicy : false ,
1061
+ },
1062
+ }
1063
+
1064
+ // Now, we'll use the ForEachSourceNodeChannel and assert that it
1065
+ // returns the expected data in the call-back.
1066
+ err = graph .ForEachSourceNodeChannel (func (chanPoint wire.OutPoint ,
1067
+ havePolicy bool , otherNode * models.LightningNode ) error {
1068
+
1069
+ require .Contains (t , expectedSrcChans , chanPoint )
1070
+ expected := expectedSrcChans [chanPoint ]
1071
+
1072
+ require .Equal (
1073
+ t , expected .otherNode [:], otherNode .PubKeyBytes [:],
1074
+ )
1075
+ require .Equal (t , expected .havePolicy , havePolicy )
1076
+
1077
+ delete (expectedSrcChans , chanPoint )
1078
+
1079
+ return nil
1080
+ })
1081
+ require .NoError (t , err )
1082
+ require .Empty (t , expectedSrcChans )
1083
+ }
1084
+
994
1085
func TestGraphTraversal (t * testing.T ) {
995
1086
t .Parallel ()
996
1087
@@ -1050,7 +1141,7 @@ func TestGraphTraversal(t *testing.T) {
1050
1141
numNodeChans := 0
1051
1142
firstNode , secondNode := nodeList [0 ], nodeList [1 ]
1052
1143
err = graph .ForEachNodeChannel (firstNode .PubKeyBytes ,
1053
- func (_ kvdb. RTx , _ * models.ChannelEdgeInfo , outEdge ,
1144
+ func (_ * models.ChannelEdgeInfo , outEdge ,
1054
1145
inEdge * models.ChannelEdgePolicy ) error {
1055
1146
1056
1147
// All channels between first and second node should
@@ -1126,26 +1217,15 @@ func TestGraphTraversalCacheable(t *testing.T) {
1126
1217
require .NoError (t , err )
1127
1218
require .Len (t , nodeMap , 0 )
1128
1219
1129
- err = graph .db .View (func (tx kvdb.RTx ) error {
1130
- for _ , node := range nodes {
1131
- err := graph .ForEachNodeChannelTx (tx , node ,
1132
- func (tx kvdb.RTx , info * models.ChannelEdgeInfo ,
1133
- policy * models.ChannelEdgePolicy ,
1134
- policy2 * models.ChannelEdgePolicy ) error { //nolint:ll
1135
-
1136
- delete (chanIndex , info .ChannelID )
1137
- return nil
1138
- },
1139
- )
1140
- if err != nil {
1141
- return err
1142
- }
1143
- }
1144
-
1145
- return nil
1146
- }, func () {})
1147
-
1148
- require .NoError (t , err )
1220
+ for _ , node := range nodes {
1221
+ err = graph .ForEachNodeDirectedChannel (
1222
+ node , func (d * DirectedChannel ) error {
1223
+ delete (chanIndex , d .ChannelID )
1224
+ return nil
1225
+ },
1226
+ )
1227
+ require .NoError (t , err )
1228
+ }
1149
1229
require .Len (t , chanIndex , 0 )
1150
1230
}
1151
1231
@@ -2802,7 +2882,7 @@ func TestIncompleteChannelPolicies(t *testing.T) {
2802
2882
2803
2883
calls := 0
2804
2884
err := graph .ForEachNodeChannel (node .PubKeyBytes ,
2805
- func (_ kvdb. RTx , _ * models.ChannelEdgeInfo , outEdge ,
2885
+ func (_ * models.ChannelEdgeInfo , outEdge ,
2806
2886
inEdge * models.ChannelEdgePolicy ) error {
2807
2887
2808
2888
if ! expectedOut && outEdge != nil {
@@ -3921,8 +4001,7 @@ func BenchmarkForEachChannel(b *testing.B) {
3921
4001
require .NoError (b , err )
3922
4002
3923
4003
for _ , n := range nodes {
3924
- cb := func (tx kvdb.RTx ,
3925
- info * models.ChannelEdgeInfo ,
4004
+ cb := func (info * models.ChannelEdgeInfo ,
3926
4005
policy * models.ChannelEdgePolicy ,
3927
4006
policy2 * models.ChannelEdgePolicy ) error { //nolint:ll
3928
4007
0 commit comments