diff --git a/message/messagemock/outbound_message_builder.go b/message/messagemock/outbound_message_builder.go index 89794619de73..f869830b0f73 100644 --- a/message/messagemock/outbound_message_builder.go +++ b/message/messagemock/outbound_message_builder.go @@ -18,6 +18,7 @@ import ( message "github.com/ava-labs/avalanchego/message" p2p "github.com/ava-labs/avalanchego/proto/pb/p2p" ips "github.com/ava-labs/avalanchego/utils/ips" + simplex "github.com/ava-labs/simplex" gomock "go.uber.org/mock/gomock" ) @@ -165,6 +166,21 @@ func (mr *OutboundMsgBuilderMockRecorder) AppResponse(chainID, requestID, msg an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppResponse", reflect.TypeOf((*OutboundMsgBuilder)(nil).AppResponse), chainID, requestID, msg) } +// BlockProposal mocks base method. +func (m *OutboundMsgBuilder) BlockProposal(chainID ids.ID, block []byte, vote simplex.Vote) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BlockProposal", chainID, block, vote) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BlockProposal indicates an expected call of BlockProposal. +func (mr *OutboundMsgBuilderMockRecorder) BlockProposal(chainID, block, vote any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BlockProposal", reflect.TypeOf((*OutboundMsgBuilder)(nil).BlockProposal), chainID, block, vote) +} + // Chits mocks base method. func (m *OutboundMsgBuilder) Chits(chainID ids.ID, requestID uint32, preferredID, preferredIDAtHeight, acceptedID ids.ID, acceptedHeight uint64) (message.OutboundMessage, error) { m.ctrl.T.Helper() @@ -180,6 +196,66 @@ func (mr *OutboundMsgBuilderMockRecorder) Chits(chainID, requestID, preferredID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chits", reflect.TypeOf((*OutboundMsgBuilder)(nil).Chits), chainID, requestID, preferredID, preferredIDAtHeight, acceptedID, acceptedHeight) } +// EmptyNotarization mocks base method. +func (m *OutboundMsgBuilder) EmptyNotarization(chainID ids.ID, protocolMetadata simplex.ProtocolMetadata, qc []byte) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EmptyNotarization", chainID, protocolMetadata, qc) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EmptyNotarization indicates an expected call of EmptyNotarization. +func (mr *OutboundMsgBuilderMockRecorder) EmptyNotarization(chainID, protocolMetadata, qc any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyNotarization", reflect.TypeOf((*OutboundMsgBuilder)(nil).EmptyNotarization), chainID, protocolMetadata, qc) +} + +// EmptyVote mocks base method. +func (m *OutboundMsgBuilder) EmptyVote(chainID ids.ID, protocolMetadata simplex.ProtocolMetadata, signature simplex.Signature) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EmptyVote", chainID, protocolMetadata, signature) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EmptyVote indicates an expected call of EmptyVote. +func (mr *OutboundMsgBuilderMockRecorder) EmptyVote(chainID, protocolMetadata, signature any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyVote", reflect.TypeOf((*OutboundMsgBuilder)(nil).EmptyVote), chainID, protocolMetadata, signature) +} + +// Finalization mocks base method. +func (m *OutboundMsgBuilder) Finalization(chainID ids.ID, blockHeader simplex.BlockHeader, qc []byte) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Finalization", chainID, blockHeader, qc) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Finalization indicates an expected call of Finalization. +func (mr *OutboundMsgBuilderMockRecorder) Finalization(chainID, blockHeader, qc any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finalization", reflect.TypeOf((*OutboundMsgBuilder)(nil).Finalization), chainID, blockHeader, qc) +} + +// FinalizeVote mocks base method. +func (m *OutboundMsgBuilder) FinalizeVote(chainID ids.ID, blockHeader simplex.BlockHeader, signature simplex.Signature) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FinalizeVote", chainID, blockHeader, signature) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FinalizeVote indicates an expected call of FinalizeVote. +func (mr *OutboundMsgBuilderMockRecorder) FinalizeVote(chainID, blockHeader, signature any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeVote", reflect.TypeOf((*OutboundMsgBuilder)(nil).FinalizeVote), chainID, blockHeader, signature) +} + // Get mocks base method. func (m *OutboundMsgBuilder) Get(chainID ids.ID, requestID uint32, deadline time.Duration, containerID ids.ID) (message.OutboundMessage, error) { m.ctrl.T.Helper() @@ -300,6 +376,21 @@ func (mr *OutboundMsgBuilderMockRecorder) Handshake(networkID, myTime, ip, clien return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*OutboundMsgBuilder)(nil).Handshake), networkID, myTime, ip, client, major, minor, patch, ipSigningTime, ipNodeIDSig, ipBLSSig, trackedSubnets, supportedACPs, objectedACPs, knownPeersFilter, knownPeersSalt, requestAllSubnetIPs) } +// Notarization mocks base method. +func (m *OutboundMsgBuilder) Notarization(chainID ids.ID, blockHeader simplex.BlockHeader, qc []byte) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Notarization", chainID, blockHeader, qc) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Notarization indicates an expected call of Notarization. +func (mr *OutboundMsgBuilderMockRecorder) Notarization(chainID, blockHeader, qc any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Notarization", reflect.TypeOf((*OutboundMsgBuilder)(nil).Notarization), chainID, blockHeader, qc) +} + // PeerList mocks base method. func (m *OutboundMsgBuilder) PeerList(peers []*ips.ClaimedIPPort, bypassThrottling bool) (message.OutboundMessage, error) { m.ctrl.T.Helper() @@ -390,6 +481,36 @@ func (mr *OutboundMsgBuilderMockRecorder) Put(chainID, requestID, container any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*OutboundMsgBuilder)(nil).Put), chainID, requestID, container) } +// ReplicationRequest mocks base method. +func (m *OutboundMsgBuilder) ReplicationRequest(chainID ids.ID, seqs []uint64, latestRound uint64) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplicationRequest", chainID, seqs, latestRound) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReplicationRequest indicates an expected call of ReplicationRequest. +func (mr *OutboundMsgBuilderMockRecorder) ReplicationRequest(chainID, seqs, latestRound any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplicationRequest", reflect.TypeOf((*OutboundMsgBuilder)(nil).ReplicationRequest), chainID, seqs, latestRound) +} + +// ReplicationResponse mocks base method. +func (m *OutboundMsgBuilder) ReplicationResponse(chainID ids.ID, data []simplex.VerifiedQuorumRound, latestRound *simplex.VerifiedQuorumRound) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplicationResponse", chainID, data, latestRound) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReplicationResponse indicates an expected call of ReplicationResponse. +func (mr *OutboundMsgBuilderMockRecorder) ReplicationResponse(chainID, data, latestRound any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplicationResponse", reflect.TypeOf((*OutboundMsgBuilder)(nil).ReplicationResponse), chainID, data, latestRound) +} + // StateSummaryFrontier mocks base method. func (m *OutboundMsgBuilder) StateSummaryFrontier(chainID ids.ID, requestID uint32, summary []byte) (message.OutboundMessage, error) { m.ctrl.T.Helper() @@ -404,3 +525,18 @@ func (mr *OutboundMsgBuilderMockRecorder) StateSummaryFrontier(chainID, requestI mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateSummaryFrontier", reflect.TypeOf((*OutboundMsgBuilder)(nil).StateSummaryFrontier), chainID, requestID, summary) } + +// Vote mocks base method. +func (m *OutboundMsgBuilder) Vote(chainID ids.ID, blockHeader simplex.BlockHeader, signature simplex.Signature) (message.OutboundMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Vote", chainID, blockHeader, signature) + ret0, _ := ret[0].(message.OutboundMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Vote indicates an expected call of Vote. +func (mr *OutboundMsgBuilderMockRecorder) Vote(chainID, blockHeader, signature any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Vote", reflect.TypeOf((*OutboundMsgBuilder)(nil).Vote), chainID, blockHeader, signature) +} diff --git a/message/outbound_msg_builder.go b/message/outbound_msg_builder.go index 042236c81482..52bca0a67e98 100644 --- a/message/outbound_msg_builder.go +++ b/message/outbound_msg_builder.go @@ -181,6 +181,8 @@ type OutboundMsgBuilder interface { chainID ids.ID, msg []byte, ) (OutboundMessage, error) + + SimplexOutboundMessageBuilder } type outMsgBuilder struct { diff --git a/message/simplex_outbound_msg_builder.go b/message/simplex_outbound_msg_builder.go new file mode 100644 index 000000000000..178bdef88bd4 --- /dev/null +++ b/message/simplex_outbound_msg_builder.go @@ -0,0 +1,347 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package message + +import ( + "github.com/ava-labs/simplex" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/p2p" +) + +type SimplexOutboundMessageBuilder interface { + BlockProposal( + chainID ids.ID, + block []byte, + vote simplex.Vote, + ) (OutboundMessage, error) + + Vote( + chainID ids.ID, + blockHeader simplex.BlockHeader, + signature simplex.Signature, + ) (OutboundMessage, error) + + EmptyVote( + chainID ids.ID, + protocolMetadata simplex.ProtocolMetadata, + signature simplex.Signature, + ) (OutboundMessage, error) + + FinalizeVote( + chainID ids.ID, + blockHeader simplex.BlockHeader, + signature simplex.Signature, + ) (OutboundMessage, error) + + Notarization( + chainID ids.ID, + blockHeader simplex.BlockHeader, + qc []byte, + ) (OutboundMessage, error) + + EmptyNotarization( + chainID ids.ID, + protocolMetadata simplex.ProtocolMetadata, + qc []byte, + ) (OutboundMessage, error) + + Finalization( + chainID ids.ID, + blockHeader simplex.BlockHeader, + qc []byte, + ) (OutboundMessage, error) + + ReplicationRequest( + chainID ids.ID, + seqs []uint64, + latestRound uint64, + ) (OutboundMessage, error) + + ReplicationResponse( + chainID ids.ID, + data []simplex.VerifiedQuorumRound, + latestRound *simplex.VerifiedQuorumRound, + ) (OutboundMessage, error) +} + +func (b *outMsgBuilder) BlockProposal( + chainID ids.ID, + block []byte, + vote simplex.Vote, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_BlockProposal{ + BlockProposal: &p2p.BlockProposal{ + Block: block, + Vote: &p2p.Vote{ + BlockHeader: simplexBlockheaderToP2P(vote.Vote.BlockHeader), + Signature: &p2p.Signature{ + Signer: vote.Signature.Signer, + Value: vote.Signature.Value, + }, + }, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) Vote( + chainID ids.ID, + blockHeader simplex.BlockHeader, + signature simplex.Signature, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_Vote{ + Vote: &p2p.Vote{ + BlockHeader: simplexBlockheaderToP2P(blockHeader), + Signature: &p2p.Signature{ + Signer: signature.Signer, + Value: signature.Value, + }, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) EmptyVote( + chainID ids.ID, + protocolMetadata simplex.ProtocolMetadata, + signature simplex.Signature, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_EmptyVote{ + EmptyVote: &p2p.EmptyVote{ + Metadata: simplexProtocolMetadataToP2P(protocolMetadata), + Signature: &p2p.Signature{ + Signer: signature.Signer, + Value: signature.Value, + }, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) FinalizeVote( + chainID ids.ID, + blockHeader simplex.BlockHeader, + signature simplex.Signature, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_FinalizeVote{ + FinalizeVote: &p2p.Vote{ + BlockHeader: simplexBlockheaderToP2P(blockHeader), + Signature: &p2p.Signature{ + Signer: signature.Signer, + Value: signature.Value, + }, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) Notarization( + chainID ids.ID, + blockHeader simplex.BlockHeader, + qc []byte, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_Notarization{ + Notarization: &p2p.QuorumCertificate{ + BlockHeader: simplexBlockheaderToP2P(blockHeader), + QuorumCertificate: qc, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) EmptyNotarization( + chainID ids.ID, + protocolMetadata simplex.ProtocolMetadata, + qc []byte, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_EmptyNotarization{ + EmptyNotarization: &p2p.EmptyNotarization{ + Metadata: simplexProtocolMetadataToP2P(protocolMetadata), + QuorumCertificate: qc, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) Finalization( + chainID ids.ID, + blockHeader simplex.BlockHeader, + qc []byte, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_Finalization{ + Finalization: &p2p.QuorumCertificate{ + BlockHeader: simplexBlockheaderToP2P(blockHeader), + QuorumCertificate: qc, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) ReplicationRequest( + chainID ids.ID, + seqs []uint64, + latestRound uint64, +) (OutboundMessage, error) { + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_ReplicationRequest{ + ReplicationRequest: &p2p.ReplicationRequest{ + Seqs: seqs, + LatestRound: latestRound, + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func (b *outMsgBuilder) ReplicationResponse( + chainID ids.ID, + data []simplex.VerifiedQuorumRound, + latestRound *simplex.VerifiedQuorumRound, +) (OutboundMessage, error) { + qrs := make([]*p2p.QuorumRound, 0, len(data)) + for _, qr := range data { + qrs = append(qrs, simplexQuorumRoundToP2P(&qr)) + } + + return b.builder.createOutbound( + &p2p.Message{ + Message: &p2p.Message_Simplex{ + Simplex: &p2p.Simplex{ + ChainId: chainID[:], + Message: &p2p.Simplex_ReplicationResponse{ + ReplicationResponse: &p2p.ReplicationResponse{ + Data: qrs, + LatestRound: simplexQuorumRoundToP2P(latestRound), + }, + }, + }, + }, + }, + b.compressionType, + false, + ) +} + +func simplexBlockheaderToP2P(bh simplex.BlockHeader) *p2p.BlockHeader { + return &p2p.BlockHeader{ + Metadata: simplexProtocolMetadataToP2P(bh.ProtocolMetadata), + Digest: bh.Digest[:], + } +} + +func simplexProtocolMetadataToP2P(md simplex.ProtocolMetadata) *p2p.ProtocolMetadata { + return &p2p.ProtocolMetadata{ + Version: uint32(md.Version), + Epoch: md.Epoch, + Round: md.Round, + Seq: md.Seq, + Prev: md.Prev[:], + } +} + +func simplexQuorumRoundToP2P(qr *simplex.VerifiedQuorumRound) *p2p.QuorumRound { + p2pQR := &p2p.QuorumRound{} + + if qr.VerifiedBlock != nil { + p2pQR.Block = qr.VerifiedBlock.Bytes() + } + if qr.Notarization != nil { + p2pQR.Notarization = &p2p.QuorumCertificate{ + BlockHeader: simplexBlockheaderToP2P(qr.Notarization.Vote.BlockHeader), + QuorumCertificate: qr.Notarization.QC.Bytes(), + } + } + if qr.Finalization != nil { + p2pQR.Finalization = &p2p.QuorumCertificate{ + BlockHeader: simplexBlockheaderToP2P(qr.Finalization.Finalization.BlockHeader), + QuorumCertificate: qr.Finalization.QC.Bytes(), + } + } + if qr.EmptyNotarization != nil { + p2pQR.EmptyNotarization = &p2p.EmptyNotarization{ + Metadata: simplexProtocolMetadataToP2P(qr.EmptyNotarization.Vote.ProtocolMetadata), + QuorumCertificate: qr.EmptyNotarization.QC.Bytes(), + } + } + return p2pQR +} diff --git a/simplex/bls_test.go b/simplex/bls_test.go index cc38dd6d918d..9222e86912c0 100644 --- a/simplex/bls_test.go +++ b/simplex/bls_test.go @@ -13,8 +13,7 @@ import ( ) func TestBLSVerifier(t *testing.T) { - config, err := newEngineConfig() - require.NoError(t, err) + config := newEngineConfig(t, 1) signer, verifier := NewBLSAuth(config) otherNodeID := ids.GenerateTestNodeID() @@ -81,7 +80,7 @@ func TestBLSVerifier(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err = verifier.Verify(msg, tt.sig, tt.nodeID) + err := verifier.Verify(msg, tt.sig, tt.nodeID) require.ErrorIs(t, err, tt.expectErr) }) } diff --git a/simplex/comm.go b/simplex/comm.go new file mode 100644 index 000000000000..a61bbd9e215a --- /dev/null +++ b/simplex/comm.go @@ -0,0 +1,128 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package simplex + +import ( + "errors" + "fmt" + "slices" + "strings" + + "github.com/ava-labs/simplex" + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/networking/sender" + "github.com/ava-labs/avalanchego/subnets" + "github.com/ava-labs/avalanchego/utils/set" +) + +var errNodeNotFound = errors.New("node not found in the validator list") + +type Comm struct { + logger simplex.Logger + subnetID ids.ID + chainID ids.ID + // nodeID is this nodes ID + nodeID simplex.NodeID + // nodes are the IDs of all the nodes in the subnet + nodes []simplex.NodeID + // sender is used to send messages to other nodes + sender sender.ExternalSender + msgBuilder message.SimplexOutboundMessageBuilder +} + +func NewComm(config *Config) (*Comm, error) { + nodes := make([]simplex.NodeID, 0, len(config.Validators)) + + // grab all the nodes that are validators for the subnet + for _, vd := range config.Validators { + nodes = append(nodes, vd.NodeID[:]) + } + + if _, ok := config.Validators[config.Ctx.NodeID]; !ok { + config.Log.Warn("Node is not a validator for the subnet", + zap.String("nodeID", config.Ctx.NodeID.String()), + zap.String("chainID", config.Ctx.ChainID.String()), + zap.String("subnetID", config.Ctx.SubnetID.String()), + ) + return nil, fmt.Errorf("%w could not find our node: %s", errNodeNotFound, config.Ctx.NodeID) + } + + sortedNodes := sortNodes(nodes) + + c := &Comm{ + subnetID: config.Ctx.SubnetID, + nodes: sortedNodes, + nodeID: config.Ctx.NodeID[:], + logger: config.Log, + sender: config.Sender, + msgBuilder: config.OutboundMsgBuilder, + chainID: config.Ctx.ChainID, + } + + return c, nil +} + +func sortNodes(nodes []simplex.NodeID) []simplex.NodeID { + slices.SortFunc(nodes, func(i, j simplex.NodeID) int { + return strings.Compare(i.String(), j.String()) + }) + return nodes +} + +func (c *Comm) ListNodes() []simplex.NodeID { + return c.nodes +} + +func (c *Comm) SendMessage(msg *simplex.Message, destination simplex.NodeID) { + outboundMsg, err := c.simplexMessageToOutboundMessage(msg) + if err != nil { + c.logger.Error("Failed creating message", zap.Error(err)) + return + } + + dest := ids.NodeID(destination) + + c.sender.Send(outboundMsg, common.SendConfig{NodeIDs: set.Of(dest)}, c.subnetID, subnets.NoOpAllower) +} + +func (c *Comm) Broadcast(msg *simplex.Message) { + for _, node := range c.nodes { + if node.Equals(c.nodeID) { + continue + } + + c.SendMessage(msg, node) + } +} + +func (c *Comm) simplexMessageToOutboundMessage(msg *simplex.Message) (message.OutboundMessage, error) { + var outboundMessage message.OutboundMessage + var err error + switch { + case msg.VerifiedBlockMessage != nil: + outboundMessage, err = c.msgBuilder.BlockProposal(c.chainID, msg.VerifiedBlockMessage.VerifiedBlock.Bytes(), msg.VerifiedBlockMessage.Vote) + case msg.VoteMessage != nil: + outboundMessage, err = c.msgBuilder.Vote(c.chainID, msg.VoteMessage.Vote.BlockHeader, msg.VoteMessage.Signature) + case msg.EmptyVoteMessage != nil: + outboundMessage, err = c.msgBuilder.EmptyVote(c.chainID, msg.EmptyVoteMessage.Vote.ProtocolMetadata, msg.EmptyVoteMessage.Signature) + case msg.FinalizeVote != nil: + outboundMessage, err = c.msgBuilder.FinalizeVote(c.chainID, msg.FinalizeVote.Finalization.BlockHeader, msg.FinalizeVote.Signature) + case msg.Notarization != nil: + outboundMessage, err = c.msgBuilder.Notarization(c.chainID, msg.Notarization.Vote.BlockHeader, msg.Notarization.QC.Bytes()) + case msg.EmptyNotarization != nil: + outboundMessage, err = c.msgBuilder.EmptyNotarization(c.chainID, msg.EmptyNotarization.Vote.ProtocolMetadata, msg.EmptyNotarization.QC.Bytes()) + case msg.Finalization != nil: + outboundMessage, err = c.msgBuilder.Finalization(c.chainID, msg.Finalization.Finalization.BlockHeader, msg.Finalization.QC.Bytes()) + case msg.ReplicationRequest != nil: + outboundMessage, err = c.msgBuilder.ReplicationRequest(c.chainID, msg.ReplicationRequest.Seqs, msg.ReplicationRequest.LatestRound) + case msg.VerifiedReplicationResponse != nil: + outboundMessage, err = c.msgBuilder.ReplicationResponse(c.chainID, msg.VerifiedReplicationResponse.Data, msg.VerifiedReplicationResponse.LatestRound) + } + + return outboundMessage, err +} diff --git a/simplex/comm_test.go b/simplex/comm_test.go new file mode 100644 index 000000000000..eb6b5c82a2fc --- /dev/null +++ b/simplex/comm_test.go @@ -0,0 +1,98 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package simplex + +import ( + "testing" + + "github.com/ava-labs/simplex" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message/messagemock" + "github.com/ava-labs/avalanchego/snow/networking/sender/sendermock" +) + +var testSimplexMessage = simplex.Message{ + VoteMessage: &simplex.Vote{ + Vote: simplex.ToBeSignedVote{ + BlockHeader: simplex.BlockHeader{ + ProtocolMetadata: simplex.ProtocolMetadata{ + Version: 1, + Epoch: 1, + Round: 1, + Seq: 1, + }, + }, + }, + Signature: simplex.Signature{ + Signer: []byte("dummy_node_id"), + Value: []byte("dummy_signature"), + }, + }, +} + +func TestCommSendMessage(t *testing.T) { + config := newEngineConfig(t, 1) + + destinationNodeID := ids.GenerateTestNodeID() + + ctrl := gomock.NewController(t) + msgCreator := messagemock.NewOutboundMsgBuilder(ctrl) + sender := sendermock.NewExternalSender(ctrl) + + config.OutboundMsgBuilder = msgCreator + config.Sender = sender + + comm, err := NewComm(config) + require.NoError(t, err) + + msgCreator.EXPECT().Vote(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, nil) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), comm.subnetID, gomock.Any()) + + comm.SendMessage(&testSimplexMessage, destinationNodeID[:]) +} + +// TestCommBroadcast tests the Broadcast method sends to all nodes in the subnet +// not including the sending node. +func TestCommBroadcast(t *testing.T) { + config := newEngineConfig(t, 3) + + ctrl := gomock.NewController(t) + msgCreator := messagemock.NewOutboundMsgBuilder(ctrl) + sender := sendermock.NewExternalSender(ctrl) + + config.OutboundMsgBuilder = msgCreator + config.Sender = sender + + comm, err := NewComm(config) + require.NoError(t, err) + + // should only send twice since the current node does not send to itself + msgCreator.EXPECT().Vote(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, nil).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any(), comm.subnetID, gomock.Any()).Times(2) + + comm.Broadcast(&testSimplexMessage) +} + +func TestCommFailsWithoutCurrentNode(t *testing.T) { + config := newEngineConfig(t, 3) + + ctrl := gomock.NewController(t) + msgCreator := messagemock.NewOutboundMsgBuilder(ctrl) + sender := sendermock.NewExternalSender(ctrl) + + config.OutboundMsgBuilder = msgCreator + config.Sender = sender + + // set the curNode to a different nodeID than the one in the config + vdrs := generateTestValidators(t, 3) + config.Validators = newTestValidators(vdrs) + + _, err := NewComm(config) + require.ErrorIs(t, err, errNodeNotFound) +} diff --git a/simplex/config.go b/simplex/config.go index e95f080e2249..54ab79b35514 100644 --- a/simplex/config.go +++ b/simplex/config.go @@ -5,6 +5,8 @@ package simplex import ( "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/networking/sender" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -14,6 +16,9 @@ type Config struct { Ctx SimplexChainContext Log logging.Logger + Sender sender.ExternalSender + OutboundMsgBuilder message.SimplexOutboundMessageBuilder + // Validators is a map of node IDs to their validator information. // This tells the node about the current membership set, and should be consistent // across all nodes in the subnet. @@ -31,6 +36,9 @@ type SimplexChainContext struct { // ChainID is the ID of the chain this context exists within. ChainID ids.ID + // SubnetID is the ID of the subnet this context exists within. + SubnetID ids.ID + // NodeID is the ID of this node NetworkID uint32 } diff --git a/simplex/test_util.go b/simplex/test_util.go deleted file mode 100644 index 5b167bd17e8d..000000000000 --- a/simplex/test_util.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package simplex - -import ( - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/constants" - "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner" -) - -func newTestValidatorInfo(allVds []validators.GetValidatorOutput) map[ids.NodeID]*validators.GetValidatorOutput { - vds := make(map[ids.NodeID]*validators.GetValidatorOutput, len(allVds)) - for _, vd := range allVds { - vds[vd.NodeID] = &vd - } - - return vds -} - -func newEngineConfig() (*Config, error) { - ls, err := localsigner.New() - if err != nil { - return nil, err - } - - nodeID := ids.GenerateTestNodeID() - - simplexChainContext := SimplexChainContext{ - NodeID: nodeID, - ChainID: ids.GenerateTestID(), - NetworkID: constants.UnitTestID, - } - - nodeInfo := validators.GetValidatorOutput{ - NodeID: nodeID, - PublicKey: ls.PublicKey(), - } - - return &Config{ - Ctx: simplexChainContext, - Validators: newTestValidatorInfo([]validators.GetValidatorOutput{nodeInfo}), - SignBLS: ls.Sign, - }, nil -} diff --git a/simplex/util_test.go b/simplex/util_test.go new file mode 100644 index 000000000000..14948d19c46a --- /dev/null +++ b/simplex/util_test.go @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package simplex + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner" + "github.com/ava-labs/avalanchego/utils/logging" +) + +func newTestValidators(allVds []validators.GetValidatorOutput) map[ids.NodeID]*validators.GetValidatorOutput { + vds := make(map[ids.NodeID]*validators.GetValidatorOutput, len(allVds)) + for _, vd := range allVds { + vds[vd.NodeID] = &vd + } + + return vds +} + +func newEngineConfig(t *testing.T, numNodes uint64) *Config { + if numNodes == 0 { + require.FailNow(t, "numNodes must be greater than 0") + } + + ls, err := localsigner.New() + require.NoError(t, err) + + nodeID := ids.GenerateTestNodeID() + + simplexChainContext := SimplexChainContext{ + NodeID: nodeID, + ChainID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NetworkID: constants.UnitTestID, + } + + nodeInfo := validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: ls.PublicKey(), + } + + validators := generateTestValidators(t, numNodes-1) + validators = append(validators, nodeInfo) + return &Config{ + Ctx: simplexChainContext, + Log: logging.NoLog{}, + Validators: newTestValidators(validators), + SignBLS: ls.Sign, + } +} + +func generateTestValidators(t *testing.T, num uint64) []validators.GetValidatorOutput { + vds := make([]validators.GetValidatorOutput, num) + for i := uint64(0); i < num; i++ { + ls, err := localsigner.New() + require.NoError(t, err) + + nodeID := ids.GenerateTestNodeID() + vds[i] = validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: ls.PublicKey(), + } + } + return vds +}