Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/release-notes-0.20.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ circuit. The indices are only available for forwarding events saved after v0.20.
* [7](https://github.com/lightningnetwork/lnd/pull/9937)
* [8](https://github.com/lightningnetwork/lnd/pull/9938)
* [9](https://github.com/lightningnetwork/lnd/pull/9939)
* [10](https://github.com/lightningnetwork/lnd/pull/9971)

## RPC Updates

Expand Down
4 changes: 2 additions & 2 deletions graph/db/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ func TestAddEdgeProof(t *testing.T) {
t.Parallel()
ctx := context.Background()

graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)

// Add an edge with no proof.
node1 := createTestVertex(t)
Expand Down Expand Up @@ -4325,7 +4325,7 @@ func TestGraphLoading(t *testing.T) {
func TestClosedScid(t *testing.T) {
t.Parallel()

graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)

scid := lnwire.ShortChannelID{}

Expand Down
4 changes: 2 additions & 2 deletions graph/db/kv_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2150,8 +2150,8 @@ func (c *KVStore) ChanUpdatesInHorizon(startTime,
}

if len(edgesInHorizon) > 0 {
log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
float64(hits)/float64(len(edgesInHorizon)), hits,
log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
float64(hits)*100/float64(len(edgesInHorizon)), hits,
len(edgesInHorizon))
} else {
log.Debugf("ChanUpdatesInHorizon returned no edges in "+
Expand Down
261 changes: 257 additions & 4 deletions graph/db/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ type SQLQueries interface {
Channel queries.
*/
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
Expand Down Expand Up @@ -136,6 +137,12 @@ type SQLQueries interface {
GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error

/*
Closed SCID table queries.
*/
InsertClosedChannel(ctx context.Context, scid []byte) error
IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
}

// BatchedSQLQueries is a version of SQLQueries that's capable of batched
Expand Down Expand Up @@ -1096,8 +1103,8 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime,
}

if len(edges) > 0 {
log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
float64(hits)/float64(len(edges)), hits, len(edges))
log.Debugf("ChanUpdatesInHorizon hit percentage: %.2f (%d/%d)",
float64(hits)*100/float64(len(edges)), hits, len(edges))
} else {
log.Debugf("ChanUpdatesInHorizon returned no edges in "+
"horizon (%s, %s)", startTime, endTime)
Expand Down Expand Up @@ -1231,6 +1238,103 @@ func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
}, sqldb.NoOpReset)
}

// ForEachChannelCacheable iterates through all the channel edges stored
// within the graph and invokes the passed callback for each edge. The
// callback takes two edges as since this is a directed graph, both the
// in/out edges are visited. If the callback returns an error, then the
// transaction is aborted and the iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
// pointer for that particular channel edge routing policy will be
// passed into the callback.
//
// NOTE: this method is like ForEachChannel but fetches only the data
// required for the graph cache.
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
*models.CachedEdgePolicy,
*models.CachedEdgePolicy) error) error {

ctx := context.TODO()

handleChannel := func(db SQLQueries,
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {

node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}

edge := buildCacheableChannelInfo(row.Channel, node1, node2)

dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return err
}

var pol1, pol2 *models.CachedEdgePolicy
if dbPol1 != nil {
policy1, err := buildChanPolicy(
*dbPol1, edge.ChannelID, nil, node2, true,
)
if err != nil {
return err
}

pol1 = models.NewCachedPolicy(policy1)
}
if dbPol2 != nil {
policy2, err := buildChanPolicy(
*dbPol2, edge.ChannelID, nil, node1, false,
)
if err != nil {
return err
}

pol2 = models.NewCachedPolicy(policy2)
}

if err := cb(edge, pol1, pol2); err != nil {
return err
}

return nil
}

return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
lastID := int64(-1)
for {
//nolint:ll
rows, err := db.ListChannelsWithPoliciesPaginated(
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: pageSize,
},
)
if err != nil {
return err
}

if len(rows) == 0 {
break
}

for _, row := range rows {
err := handleChannel(db, row)
if err != nil {
return err
}

lastID = row.Channel.ID
}
}

return nil
}, sqldb.NoOpReset)
}

// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges as since this is a directed graph, both the in/out edges are visited.
Expand Down Expand Up @@ -1291,7 +1395,7 @@ func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
}

return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var lastID int64
lastID := int64(-1)
for {
//nolint:ll
rows, err := db.ListChannelsWithPoliciesPaginated(
Expand Down Expand Up @@ -2575,6 +2679,155 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
return removedChans, nil
}

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
proof *models.ChannelAuthProof) error {

var (
ctx = context.TODO()
scidBytes = channelIDToBytes(scid.ToUint64())
)

err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.AddV1ChannelProof(
ctx, sqlc.AddV1ChannelProofParams{
Scid: scidBytes[:],
Node1Signature: proof.NodeSig1Bytes,
Node2Signature: proof.NodeSig2Bytes,
Bitcoin1Signature: proof.BitcoinSig1Bytes,
Bitcoin2Signature: proof.BitcoinSig2Bytes,
},
)
if err != nil {
return fmt.Errorf("unable to add edge proof: %w", err)
}

n, err := res.RowsAffected()
if err != nil {
return err
}

if n == 0 {
return fmt.Errorf("no rows affected when adding edge "+
"proof for SCID %v", scid)
} else if n > 1 {
return fmt.Errorf("multiple rows affected when adding "+
"edge proof for SCID %v: %d rows affected",
scid, n)
}

return nil
}, sqldb.NoOpReset)
if err != nil {
return fmt.Errorf("unable to add edge proof: %w", err)
}

return nil
}

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
var (
ctx = context.TODO()
chanIDB = channelIDToBytes(scid.ToUint64())
)

return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
return db.InsertClosedChannel(ctx, chanIDB[:])
}, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
var (
ctx = context.TODO()
isClosed bool
chanIDB = channelIDToBytes(scid.ToUint64())
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var err error
isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
if err != nil {
return fmt.Errorf("unable to fetch closed channel: %w",
err)
}

return nil
}, sqldb.NoOpReset)
if err != nil {
return false, fmt.Errorf("unable to fetch closed channel: %w",
err)
}

return isClosed, nil
}

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
var ctx = context.TODO()

return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
}, sqldb.NoOpReset)
}

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
db SQLQueries
chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
chain chainhash.Hash) *sqlNodeTraverser {

return &sqlNodeTraverser{
db: db,
chain: chain,
}
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {

ctx := context.TODO()

return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {

ctx := context.TODO()

return fetchNodeFeatures(ctx, s.db, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
Expand Down Expand Up @@ -2704,7 +2957,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
cb func(nodeID int64, nodePub route.Vertex) error) error {

var lastID int64
lastID := int64(-1)

for {
nodes, err := db.ListNodeIDsAndPubKeys(
Expand Down
Loading
Loading